Initial commit
This commit is contained in:
1779
aws-iot-device-sdk-python/AWSIoTPythonSDK/MQTTLib.py
Normal file
1779
aws-iot-device-sdk-python/AWSIoTPythonSDK/MQTTLib.py
Normal file
File diff suppressed because it is too large
Load Diff
3
aws-iot-device-sdk-python/AWSIoTPythonSDK/__init__.py
Normal file
3
aws-iot-device-sdk-python/AWSIoTPythonSDK/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__version__ = "1.4.8"
|
||||
|
||||
|
||||
BIN
aws-iot-device-sdk-python/AWSIoTPythonSDK/__init__.pyc
Normal file
BIN
aws-iot-device-sdk-python/AWSIoTPythonSDK/__init__.pyc
Normal file
Binary file not shown.
@@ -0,0 +1,466 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import json
|
||||
|
||||
|
||||
# JSON keys used to walk the Greengrass discovery response payload
# (see DiscoveryInfo.toObjectAtGroupLevel below for the decode logic).
KEY_GROUP_LIST = "GGGroups"                  # top-level: list of groups
KEY_GROUP_ID = "GGGroupId"                   # per group: group id string
KEY_CORE_LIST = "Cores"                      # per group: list of cores
KEY_CORE_ARN = "thingArn"                    # per core: core thing arn
KEY_CA_LIST = "CAs"                          # per group: CA content strings
KEY_CONNECTIVITY_INFO_LIST = "Connectivity"  # per core: connectivity entries
KEY_CONNECTIVITY_INFO_ID = "Id"              # per entry: entry id
KEY_HOST_ADDRESS = "HostAddress"             # per entry: host address
KEY_PORT_NUMBER = "PortNumber"               # per entry: port number
KEY_METADATA = "Metadata"                    # per entry: optional metadata
|
||||
|
||||
|
||||
class ConnectivityInfo(object):
    """Read-only view over one connectivity record of a Greengrass core.

    Data model populated by the SDK from a discovery response; not intended
    to be instantiated directly from user scripts.
    """

    def __init__(self, id, host, port, metadata):
        self._info_id = id
        self._host_address = host
        self._port_number = port
        self._metadata_string = metadata

    @property
    def id(self):
        """Connectivity information id."""
        return self._info_id

    @property
    def host(self):
        """Host address."""
        return self._host_address

    @property
    def port(self):
        """Port number."""
        return self._port_number

    @property
    def metadata(self):
        """Metadata string."""
        return self._metadata_string
|
||||
|
||||
|
||||
class CoreConnectivityInfo(object):
    """Connectivity information for one Greengrass core.

    Data model populated by the SDK from a discovery response; not intended
    to be instantiated directly from user scripts.
    """

    def __init__(self, coreThingArn, groupId):
        self._arn = coreThingArn
        self._gid = groupId
        self._info_by_id = {}

    @property
    def coreThingArn(self):
        """Thing arn of this Greengrass core."""
        return self._arn

    @property
    def groupId(self):
        """Id of the Greengrass group this core belongs to."""
        return self._gid

    @property
    def connectivityInfoList(self):
        """All ConnectivityInfo records known for this core."""
        return list(self._info_by_id.values())

    def getConnectivityInfo(self, id):
        """Look up one connectivity record by its id.

        Returns the matching ConnectivityInfo object, or None when the id is
        unknown.
        """
        return self._info_by_id.get(id)

    def appendConnectivityInfo(self, connectivityInfo):
        """Register one ConnectivityInfo record for this core (SDK internal).

        An existing record with the same id is replaced.
        """
        self._info_by_id[connectivityInfo.id] = connectivityInfo
|
||||
|
||||
|
||||
class GroupConnectivityInfo(object):
    """Connectivity information for a single Greengrass group.

    Data model populated by the SDK from a discovery response; not intended
    to be instantiated directly from user scripts.
    """

    def __init__(self, groupId):
        self._gid = groupId
        self._cores_by_arn = {}
        self._group_cas = []

    @property
    def groupId(self):
        """Id of this Greengrass group."""
        return self._gid

    @property
    def coreConnectivityInfoList(self):
        """All CoreConnectivityInfo objects belonging to this group."""
        return list(self._cores_by_arn.values())

    @property
    def caList(self):
        """CA content strings for this group."""
        return self._group_cas

    def getCoreConnectivityInfo(self, coreThingArn):
        """Look up one core of this group by its thing arn.

        Returns the matching CoreConnectivityInfo object, or None when the
        arn is unknown.
        """
        return self._cores_by_arn.get(coreThingArn)

    def appendCoreConnectivityInfo(self, coreConnectivityInfo):
        """Register one CoreConnectivityInfo with this group (SDK internal).

        An existing core with the same thing arn is replaced.
        """
        self._cores_by_arn[coreConnectivityInfo.coreThingArn] = coreConnectivityInfo

    def appendCa(self, ca):
        """Append one group CA content string (SDK internal)."""
        self._group_cas.append(ca)
|
||||
|
||||
|
||||
class DiscoveryInfo(object):
    """Data model wrapping the JSON payload of a Greengrass discovery response.

    Instances are created by the SDK after a discovery request; user scripts
    normally only read from them.
    """

    def __init__(self, rawJson):
        self._raw_json = rawJson

    @property
    def rawJson(self):
        """The raw JSON response string, kept for custom processing."""
        return self._raw_json

    def getAllCores(self):
        """Return every discovered core, across all Greengrass groups.

        Returns a list of CoreConnectivityInfo objects; useful when iterating
        all available cores regardless of group membership.
        """
        return [
            core
            for group in self.getAllGroups()
            for core in group.coreConnectivityInfoList
        ]

    def getAllCas(self):
        """Return every group CA, across all Greengrass groups.

        Returns a list of ``(groupId, caContent)`` string pairs, where
        ``caContent`` is the CA content and ``groupId`` the owning group.
        """
        return [
            (group.groupId, ca)
            for group in self.getAllGroups()
            for ca in group.caList
        ]

    def getAllGroups(self):
        """Return every discovered group as GroupConnectivityInfo objects."""
        return list(self.toObjectAtGroupLevel().values())

    def toObjectAtGroupLevel(self):
        """Decode the payload into a dict of group id -> GroupConnectivityInfo.

        Designed for callers that know exactly which group, core and
        connectivity set they want to use when connecting, e.g.::

            group = myDiscoveryInfo.toObjectAtGroupLevel()["IKnowMyGroupId"]
            core = group.getCoreConnectivityInfo("IKnowMyCoreThingArn")
            conn = core.getConnectivityInfo("IKnowMyConnectivityInfoSetId")
            host, port = conn.host, conn.port
        """
        group_objects = json.loads(self._raw_json)[KEY_GROUP_LIST]
        decoded_groups = (self._decode_group_info(obj) for obj in group_objects)
        return {group.groupId: group for group in decoded_groups}

    def _decode_group_info(self, group_object):
        # Build one GroupConnectivityInfo from a single group entry.
        group_info = GroupConnectivityInfo(group_object[KEY_GROUP_ID])
        for core_object in group_object[KEY_CORE_LIST]:
            group_info.appendCoreConnectivityInfo(
                self._decode_core_info(core_object, group_info.groupId))
        for ca in group_object[KEY_CA_LIST]:
            group_info.appendCa(ca)
        return group_info

    def _decode_core_info(self, core_object, group_id):
        # Build one CoreConnectivityInfo from a single core entry.
        core_info = CoreConnectivityInfo(core_object[KEY_CORE_ARN], group_id)
        for entry in core_object[KEY_CONNECTIVITY_INFO_LIST]:
            core_info.appendConnectivityInfo(ConnectivityInfo(
                entry[KEY_CONNECTIVITY_INFO_ID],
                entry[KEY_HOST_ADDRESS],
                entry[KEY_PORT_NUMBER],
                entry.get(KEY_METADATA, '')))
        return core_info
|
||||
@@ -0,0 +1,426 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure
|
||||
from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo
|
||||
from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder
|
||||
import re
|
||||
import sys
|
||||
import ssl
|
||||
import time
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import platform
|
||||
# The errno value signalling "operation would block" on a non-blocking
# socket differs between Windows (WSAEWOULDBLOCK) and POSIX (EAGAIN).
if platform.system() == 'Windows':
    EAGAIN = errno.WSAEWOULDBLOCK
else:
    EAGAIN = errno.EAGAIN
|
||||
|
||||
|
||||
class DiscoveryInfoProvider(object):
    # Pieces used to hand-build the HTTP GET request for the discovery
    # endpoint; the request is written directly over a TLS socket rather
    # than through an HTTP client library.
    REQUEST_TYPE_PREFIX = "GET "
    PAYLOAD_PREFIX = "/greengrass/discover/thing/"
    PAYLOAD_SUFFIX = " HTTP/1.1\r\n" # Space in the front
    HOST_PREFIX = "Host: "
    HOST_SUFFIX = "\r\n\r\n"
    # Regex fragments used to parse the HTTP response header.
    HTTP_PROTOCOL = r"HTTP/1.1 "
    CONTENT_LENGTH = r"content-length: "
    CONTENT_LENGTH_PATTERN = CONTENT_LENGTH + r"([0-9]+)\r\n"
    HTTP_RESPONSE_CODE_PATTERN = HTTP_PROTOCOL + r"([0-9]+) "

    # HTTP status codes the discovery service is mapped to exceptions below.
    HTTP_SC_200 = "200"
    HTTP_SC_400 = "400"
    HTTP_SC_401 = "401"
    HTTP_SC_404 = "404"
    HTTP_SC_429 = "429"

    # Return codes of the low-level send/receive helpers.
    LOW_LEVEL_RC_COMPLETE = 0
    LOW_LEVEL_RC_TIMEOUT = -1

    _logger = logging.getLogger(__name__)

    def __init__(self, caPath="", certPath="", keyPath="", host="", port=8443, timeoutSec=120):
        """

        The class that provides functionality to perform a Greengrass discovery process to the cloud.

        Users can perform Greengrass discovery process for a specific Greengrass aware device to retrieve
        connectivity/identity information of Greengrass cores within the same group.

        **Syntax**

        .. code:: python

          from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider

          # Create a discovery information provider
          myDiscoveryInfoProvider = DiscoveryInfoProvider()
          # Create a discovery information provider with custom configuration
          myDiscoveryInfoProvider = DiscoveryInfoProvider(caPath=myCAPath, certPath=myCertPath, keyPath=myKeyPath, host=myHost, timeoutSec=myTimeoutSec)

        **Parameters**

        *caPath* - Path to read the root CA file.

        *certPath* - Path to read the certificate file.

        *keyPath* - Path to read the private key file.

        *host* - String that denotes the host name of the user-specific AWS IoT endpoint.

        *port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.

        *timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
        been timed out.

        **Returns**

        AWSIoTPythonSDK.core.greengrass.discovery.providers.DiscoveryInfoProvider object

        """
        self._ca_path = caPath
        self._cert_path = certPath
        self._key_path = keyPath
        self._host = host
        self._port = port
        self._timeout_sec = timeoutSec
        # Maps known HTTP error status codes to the exception raised for them;
        # any other non-200 code raises DiscoveryFailure (see _raise_if_not_200).
        self._expected_exception_map = {
            self.HTTP_SC_400 : DiscoveryInvalidRequestException(),
            self.HTTP_SC_401 : DiscoveryUnauthorizedException(),
            self.HTTP_SC_404 : DiscoveryDataNotFoundException(),
            self.HTTP_SC_429 : DiscoveryThrottlingException()
        }

    def configureEndpoint(self, host, port=8443):
        """

        **Description**

        Used to configure the host address and port number for the discovery request to hit. Should be called before
        the discovery request happens.

        **Syntax**

        .. code:: python

          # Using default port configuration, 8443
          myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com")
          # Customize port configuration
          myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com", port=8888)

        **Parameters**

        *host* - String that denotes the host name of the user-specific AWS IoT endpoint.

        *port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.

        **Returns**

        None

        """
        self._host = host
        self._port = port

    def configureCredentials(self, caPath, certPath, keyPath):
        """

        **Description**

        Used to configure the credentials for discovery request. Should be called before the discovery request happens.

        **Syntax**

        .. code:: python

          myDiscoveryInfoProvider.configureCredentials("my/ca/path", "my/cert/path", "my/key/path")

        **Parameters**

        *caPath* - Path to read the root CA file.

        *certPath* - Path to read the certificate file.

        *keyPath* - Path to read the private key file.

        **Returns**

        None

        """
        self._ca_path = caPath
        self._cert_path = certPath
        self._key_path = keyPath

    def configureTimeout(self, timeoutSec):
        """

        **Description**

        Used to configure the time out in seconds for discovery request sending/response waiting. Should be called before
        the discovery request happens.

        **Syntax**

        .. code:: python

          # Configure the time out for discovery to be 10 seconds
          myDiscoveryInfoProvider.configureTimeout(10)

        **Parameters**

        *timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
        been timed out.

        **Returns**

        None

        """
        self._timeout_sec = timeoutSec

    def discover(self, thingName):
        """

        **Description**

        Perform the discovery request for the given Greengrass aware device thing name.

        **Syntax**

        .. code:: python

          myDiscoveryInfoProvider.discover(thingName="myGGAD")

        **Parameters**

        *thingName* - Greengrass aware device thing name.

        **Returns**

        :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.DiscoveryInfo` object.

        """
        self._logger.info("Starting discover request...")
        self._logger.info("Endpoint: " + self._host + ":" + str(self._port))
        self._logger.info("Target thing: " + thingName)
        # Connect, hand-write the HTTP GET over TLS, then parse the response.
        sock = self._create_tcp_connection()
        ssl_sock = self._create_ssl_connection(sock)
        self._raise_on_timeout(self._send_discovery_request(ssl_sock, thingName))
        status_code, response_body = self._receive_discovery_response(ssl_sock)

        return self._raise_if_not_200(status_code, response_body)

    def _create_tcp_connection(self):
        # Open a plain TCP connection to the configured endpoint.
        self._logger.debug("Creating tcp connection...")
        try:
            if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
                # Older interpreters lack the source_address parameter.
                sock = socket.create_connection((self._host, self._port))
            else:
                sock = socket.create_connection((self._host, self._port), source_address=("", 0))
            return sock
        except socket.error as err:
            # "In progress"/"would block" are tolerated for non-blocking
            # connects; anything else is re-raised.
            if err.errno != errno.EINPROGRESS and err.errno != errno.EWOULDBLOCK and err.errno != EAGAIN:
                raise
        self._logger.debug("Created tcp connection.")
        # NOTE(review): on the tolerated-error path this method falls through
        # and implicitly returns None — confirm _create_ssl_connection's
        # callers tolerate a None socket.

    def _create_ssl_connection(self, sock):
        # Wrap the TCP socket with TLS and verify the server host name.
        self._logger.debug("Creating ssl connection...")

        ssl_protocol_version = ssl.PROTOCOL_SSLv23

        if self._port == 443:
            # Port 443 requires the ALPN protocol "x-amzn-http-ca" so the
            # endpoint can route the discovery traffic.
            ssl_context = SSLContextBuilder()\
                .with_ca_certs(self._ca_path)\
                .with_cert_key_pair(self._cert_path, self._key_path)\
                .with_cert_reqs(ssl.CERT_REQUIRED)\
                .with_check_hostname(True)\
                .with_ciphers(None)\
                .with_alpn_protocols(['x-amzn-http-ca'])\
                .build()
            ssl_sock = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False)
            ssl_sock.do_handshake()
        else:
            # NOTE(review): ssl.wrap_socket was deprecated in Python 3.7 and
            # removed in 3.12 — consider migrating to SSLContext.wrap_socket.
            ssl_sock = ssl.wrap_socket(sock,
                                       certfile=self._cert_path,
                                       keyfile=self._key_path,
                                       ca_certs=self._ca_path,
                                       cert_reqs=ssl.CERT_REQUIRED,
                                       ssl_version=ssl_protocol_version)

        self._logger.debug("Matching host name...")
        # Fall back to the manual matcher where ssl.match_hostname is missing.
        if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
            self._tls_match_hostname(ssl_sock)
        else:
            ssl.match_hostname(ssl_sock.getpeercert(), self._host)

        return ssl_sock

    def _tls_match_hostname(self, ssl_sock):
        # Manual certificate host-name verification, used when the stdlib's
        # ssl.match_hostname is unavailable (see call site above).
        # Raises ssl.SSLError when the certificate does not match self._host.
        try:
            cert = ssl_sock.getpeercert()
        except AttributeError:
            # the getpeercert can throw Attribute error: object has no attribute 'peer_certificate'
            # Don't let that crash the whole client. See also: http://bugs.python.org/issue13721
            raise ssl.SSLError('Not connected')

        # Prefer subjectAltName entries; the subject commonName is only
        # consulted when no DNS/IP SAN entries exist.
        san = cert.get('subjectAltName')
        if san:
            have_san_dns = False
            for (key, value) in san:
                if key == 'DNS':
                    have_san_dns = True
                    if self._host_matches_cert(self._host.lower(), value.lower()) == True:
                        return
                if key == 'IP Address':
                    have_san_dns = True
                    if value.lower() == self._host.lower():
                        return

            if have_san_dns:
                # Only check subject if subjectAltName dns not found.
                raise ssl.SSLError('Certificate subject does not match remote hostname.')
        subject = cert.get('subject')
        if subject:
            for ((key, value),) in subject:
                if key == 'commonName':
                    if self._host_matches_cert(self._host.lower(), value.lower()) == True:
                        return

        raise ssl.SSLError('Certificate subject does not match remote hostname.')

    def _host_matches_cert(self, host, cert_host):
        # Compare a host name against a certificate name; a "*.domain"
        # wildcard matches exactly one leading label.
        if cert_host[0:2] == "*.":
            if cert_host.count("*") != 1:
                return False

            host_match = host.split(".", 1)[1]
            cert_match = cert_host.split(".", 1)[1]
            if host_match == cert_match:
                return True
            else:
                return False
        else:
            if host == cert_host:
                return True
            else:
                return False

    def _send_discovery_request(self, ssl_sock, thing_name):
        # Hand-compose the minimal HTTP/1.1 GET request for this thing name.
        request = self.REQUEST_TYPE_PREFIX + \
            self.PAYLOAD_PREFIX + \
            thing_name + \
            self.PAYLOAD_SUFFIX + \
            self.HOST_PREFIX + \
            self._host + ":" + str(self._port) + \
            self.HOST_SUFFIX
        self._logger.debug("Sending discover request: " + request)

        # Keep writing until the whole request is out or the deadline passes.
        start_time = time.time()
        desired_length_to_write = len(request)
        actual_length_written = 0
        while True:
            try:
                length_written = ssl_sock.write(request.encode("utf-8"))
                actual_length_written += length_written
            except socket.error as err:
                # SSL want-read/want-write means "retry the write later".
                # NOTE(review): other socket errors are swallowed here too and
                # only surface as LOW_LEVEL_RC_TIMEOUT — confirm intended.
                if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
                    pass
            if actual_length_written == desired_length_to_write:
                return self.LOW_LEVEL_RC_COMPLETE
            if start_time + self._timeout_sec < time.time():
                return self.LOW_LEVEL_RC_TIMEOUT

    def _receive_discovery_response(self, ssl_sock):
        # Read the header up to the blank line, then read exactly
        # content-length bytes of body.
        self._logger.debug("Receiving discover response header...")
        rc1, response_header = self._receive_until(ssl_sock, self._got_two_crlfs)
        status_code, body_length = self._handle_discovery_response_header(rc1, response_header.decode("utf-8"))

        self._logger.debug("Receiving discover response body...")
        rc2, response_body = self._receive_until(ssl_sock, self._got_enough_bytes, body_length)
        response_body = self._handle_discovery_response_body(rc2, response_body.decode("utf-8"))

        return status_code, response_body

    def _receive_until(self, ssl_sock, criteria_function, extra_data=None):
        # Read one byte at a time until criteria_function is satisfied or the
        # configured timeout elapses; returns (rc, bytearray-so-far).
        start_time = time.time()
        response = bytearray()
        number_bytes_read = 0
        while True:  # Python does not have do-while
            try:
                response.append(self._convert_to_int_py3(ssl_sock.read(1)))
                number_bytes_read += 1
            except socket.error as err:
                # SSL want-read/want-write means "retry the read later"; see
                # the NOTE in _send_discovery_request about other errors.
                if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
                    pass

            if criteria_function((number_bytes_read, response, extra_data)):
                return self.LOW_LEVEL_RC_COMPLETE, response
            if start_time + self._timeout_sec < time.time():
                return self.LOW_LEVEL_RC_TIMEOUT, response

    def _convert_to_int_py3(self, input_char):
        # ord() covers both Py2 str chars and Py3 length-1 bytes; on failure
        # the input is returned unchanged.
        # NOTE(review): the bare except also hides unrelated errors.
        try:
            return ord(input_char)
        except:
            return input_char

    def _got_enough_bytes(self, data):
        # _receive_until criteria: body is complete once content-length
        # (passed through as extra_data) bytes have been read.
        number_bytes_read, response, target_length = data
        return number_bytes_read == int(target_length)

    def _got_two_crlfs(self, data):
        # _receive_until criteria: header is complete once the stream ends
        # with two consecutive CRLFs ("\r\n\r\n").
        number_bytes_read, response, extra_data_unused = data
        number_of_crlf = 2
        has_enough_bytes = number_bytes_read > number_of_crlf * 2 - 1
        if has_enough_bytes:
            end_of_received = response[number_bytes_read - number_of_crlf * 2 : number_bytes_read]
            expected_end_of_response = b"\r\n" * number_of_crlf
            return end_of_received == expected_end_of_response
        else:
            return False

    def _handle_discovery_response_header(self, rc, response):
        # Extract (statusCode, contentLength) strings from the raw header.
        self._raise_on_timeout(rc)
        # NOTE(review): if either pattern fails to match, .group(1) raises
        # AttributeError on None — assumes the service always sends both.
        http_status_code_matcher = re.compile(self.HTTP_RESPONSE_CODE_PATTERN)
        http_status_code_matched_groups = http_status_code_matcher.match(response)
        content_length_matcher = re.compile(self.CONTENT_LENGTH_PATTERN)
        content_length_matched_groups = content_length_matcher.search(response)
        return http_status_code_matched_groups.group(1), content_length_matched_groups.group(1)

    def _handle_discovery_response_body(self, rc, response):
        # Body is passed through unchanged once the timeout check passes.
        self._raise_on_timeout(rc)
        return response

    def _raise_on_timeout(self, rc):
        # Translate the low-level timeout return code into an exception.
        if rc == self.LOW_LEVEL_RC_TIMEOUT:
            raise DiscoveryTimeoutException()

    def _raise_if_not_200(self, status_code, response_body):  # response_body here is str in Py3
        # Map non-200 status codes to their exceptions (or a generic
        # DiscoveryFailure); wrap a 200 body into a DiscoveryInfo object.
        if status_code != self.HTTP_SC_200:
            expected_exception = self._expected_exception_map.get(status_code)
            if expected_exception:
                raise expected_exception
            else:
                raise DiscoveryFailure(response_body)
        return DiscoveryInfo(response_body)
|
||||
@@ -0,0 +1,156 @@
|
||||
# /*
|
||||
# * Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import json
|
||||
|
||||
# Topic-building vocabulary for the AWS IoT Jobs MQTT API; job topics are
# assembled from these segments under "$aws/things/<thingName>/...".
_BASE_THINGS_TOPIC = "$aws/things/"
_NOTIFY_OPERATION = "notify"
_NOTIFY_NEXT_OPERATION = "notify-next"
_GET_OPERATION = "get"
_START_NEXT_OPERATION = "start-next"
_WILDCARD_OPERATION = "+"
_UPDATE_OPERATION = "update"
_ACCEPTED_REPLY = "accepted"
_REJECTED_REPLY = "rejected"
_WILDCARD_REPLY = "#"

# Tuple indices into the jobExecutionTopicType members defined below
# (members of that enum-like class are tuples).
_JOB_ID_REQUIRED_INDEX = 1   # bool: does this topic type embed a job id?
_JOB_OPERATION_INDEX = 2     # str: operation segment of the topic

# JSON payload keys for job execution requests/responses.
_STATUS_KEY = 'status'
_STATUS_DETAILS_KEY = 'statusDetails'
_EXPECTED_VERSION_KEY = 'expectedVersion'
_EXEXCUTION_NUMBER_KEY = 'executionNumber'  # NOTE: constant name has a typo ("EXEXCUTION"); the key value is correct
_INCLUDE_JOB_EXECUTION_STATE_KEY = 'includeJobExecutionState'
_INCLUDE_JOB_DOCUMENT_KEY = 'includeJobDocument'
_CLIENT_TOKEN_KEY = 'clientToken'
_STEP_TIMEOUT_IN_MINUTES_KEY = 'stepTimeoutInMinutes'
|
||||
|
||||
# The type of job topic. Members are tuples of
# (ordinal, jobIdRequired, operationSegment); see _JOB_ID_REQUIRED_INDEX and
# _JOB_OPERATION_INDEX for the positional semantics.
class jobExecutionTopicType(object):
    JOB_UNRECOGNIZED_TOPIC = (0, False, '')
    JOB_GET_PENDING_TOPIC = (1, False, _GET_OPERATION)
    JOB_START_NEXT_TOPIC = (2, False, _START_NEXT_OPERATION)
    JOB_DESCRIBE_TOPIC = (3, True, _GET_OPERATION)
    JOB_UPDATE_TOPIC = (4, True, _UPDATE_OPERATION)
    JOB_NOTIFY_TOPIC = (5, False, _NOTIFY_OPERATION)
    JOB_NOTIFY_NEXT_TOPIC = (6, False, _NOTIFY_NEXT_OPERATION)
    JOB_WILDCARD_TOPIC = (7, False, _WILDCARD_OPERATION)
|
||||
|
||||
#Members of this enum are tuples
# Tuple index into jobExecutionTopicReplyType members: the topic suffix string.
_JOB_SUFFIX_INDEX = 1
#The type of reply topic, or #JOB_REQUEST_TYPE for topics that are not replies.
class jobExecutionTopicReplyType(object):
    """Enum-like holder for reply-topic variants.

    Each member is a tuple: (ordinal, topicSuffix) — the suffix is appended
    to a request topic ('' for plain requests, '/accepted', '/rejected', '/#').
    """
    JOB_UNRECOGNIZED_TOPIC_TYPE = (0, '')
    JOB_REQUEST_TYPE = (1, '')
    JOB_ACCEPTED_REPLY_TYPE = (2, '/' + _ACCEPTED_REPLY)
    JOB_REJECTED_REPLY_TYPE = (3, '/' + _REJECTED_REPLY)
    JOB_WILDCARD_REPLY_TYPE = (4, '/' + _WILDCARD_REPLY)
|
||||
|
||||
# Tuple index into jobExecutionStatus members: the wire-format status string.
_JOB_STATUS_INDEX = 1
class jobExecutionStatus(object):
    """Enum-like holder for job execution states.

    Each member is a tuple: (ordinal, statusString); statusString is None
    for the not-set/unknown sentinels.
    """
    JOB_EXECUTION_STATUS_NOT_SET = (0, None)
    JOB_EXECUTION_QUEUED = (1, 'QUEUED')
    JOB_EXECUTION_IN_PROGRESS = (2, 'IN_PROGRESS')
    JOB_EXECUTION_FAILED = (3, 'FAILED')
    JOB_EXECUTION_SUCCEEDED = (4, 'SUCCEEDED')
    JOB_EXECUTION_CANCELED = (5, 'CANCELED')
    JOB_EXECUTION_REJECTED = (6, 'REJECTED')
    JOB_EXECUTION_UNKNOWN_STATUS = (99, None)
|
||||
|
||||
def _getExecutionStatus(jobStatus):
    """Return the wire-format status string for a jobExecutionStatus member.

    Returns None when *jobStatus* is malformed (not subscriptable, or lacking
    the status element), mirroring the "status not set" sentinel.
    """
    try:
        return jobStatus[_JOB_STATUS_INDEX]
    except (KeyError, IndexError, TypeError):
        # Bug fix: members are tuples, so bad input raises IndexError or
        # TypeError — never KeyError, which was the only exception caught
        # before, making the None fallback unreachable.
        return None
|
||||
|
||||
def _isWithoutJobIdTopicType(srcJobExecTopicType):
    """Return True for topic types that must not carry a job ID."""
    topics_without_job_id = (
        jobExecutionTopicType.JOB_GET_PENDING_TOPIC,
        jobExecutionTopicType.JOB_START_NEXT_TOPIC,
        jobExecutionTopicType.JOB_NOTIFY_TOPIC,
        jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC,
    )
    return srcJobExecTopicType in topics_without_job_id
|
||||
|
||||
class thingJobManager:
    """Builds AWS IoT Jobs MQTT topic names and JSON request payloads for one thing.

    An optional client token is echoed into every serialized payload so that
    responses can be correlated with their requests.
    """

    def __init__(self, thingName, clientToken = None):
        # Name of the IoT thing these topics/payloads are scoped to
        self._thingName = thingName
        # Optional correlation token included in serialized payloads
        self._clientToken = clientToken

    def getJobTopic(self, srcJobExecTopicType, srcJobExecTopicReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None):
        """Return the MQTT topic for the given topic/reply type combination.

        Returns None for any invalid combination: no thing name, a reply type
        on a notify topic, a job ID on a topic that forbids one, a missing job
        ID where required, or an unrecognized (empty) operation.
        """
        if self._thingName is None:
            return None

        #Verify topics that only support request type, actually have request type specified for reply
        if (srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) and srcJobExecTopicReplyType != jobExecutionTopicReplyType.JOB_REQUEST_TYPE:
            return None

        #Verify topics that explicitly do not want a job ID do not have one specified
        if (jobId is not None and _isWithoutJobIdTopicType(srcJobExecTopicType)):
            return None

        #Verify job ID is present if the topic requires one
        if jobId is None and srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
            return None

        #Ensure the job operation is a non-empty string
        if srcJobExecTopicType[_JOB_OPERATION_INDEX] == '':
            return None

        if srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
            # e.g. $aws/things/<thing>/jobs/<jobId>/<operation>[/accepted|/rejected|/#]
            return '{0}{1}/jobs/{2}/{3}{4}'.format(_BASE_THINGS_TOPIC, self._thingName, str(jobId), srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
        elif srcJobExecTopicType == jobExecutionTopicType.JOB_WILDCARD_TOPIC:
            # Wildcard covers every jobs topic for this thing; reply type is ignored
            return '{0}{1}/jobs/#'.format(_BASE_THINGS_TOPIC, self._thingName)
        else:
            # e.g. $aws/things/<thing>/jobs/<operation>[/accepted|/rejected|/#]
            return '{0}{1}/jobs/{2}{3}'.format(_BASE_THINGS_TOPIC, self._thingName, srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])

    def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None):
        """Serialize an UpdateJobExecution request body as a JSON string.

        Returns None when *status* is not a recognized jobExecutionStatus
        member. Optional fields are included only when set/positive.
        """
        executionStatus = _getExecutionStatus(status)
        if executionStatus is None:
            return None
        payload = {_STATUS_KEY: executionStatus}
        if statusDetails:
            payload[_STATUS_DETAILS_KEY] = statusDetails
        if expectedVersion > 0:
            # Numeric fields are converted to strings before serialization here
            payload[_EXPECTED_VERSION_KEY] = str(expectedVersion)
        if executionNumber > 0:
            payload[_EXEXCUTION_NUMBER_KEY] = str(executionNumber)
        if includeJobExecutionState:
            payload[_INCLUDE_JOB_EXECUTION_STATE_KEY] = True
        if includeJobDocument:
            payload[_INCLUDE_JOB_DOCUMENT_KEY] = True
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        if stepTimeoutInMinutes is not None:
            payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
        return json.dumps(payload)

    def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocument=True):
        """Serialize a DescribeJobExecution request body as a JSON string.

        Note: executionNumber is serialized as a number here (unlike the
        string conversion in serializeJobExecutionUpdatePayload).
        """
        payload = {_INCLUDE_JOB_DOCUMENT_KEY: includeJobDocument}
        if executionNumber > 0:
            payload[_EXEXCUTION_NUMBER_KEY] = executionNumber
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        return json.dumps(payload)

    def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None, stepTimeoutInMinutes=None):
        """Serialize a StartNextPendingJobExecution request body as a JSON string."""
        payload = {}
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        if statusDetails is not None:
            payload[_STATUS_DETAILS_KEY] = statusDetails
        if stepTimeoutInMinutes is not None:
            payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
        return json.dumps(payload)

    def serializeClientTokenPayload(self):
        """Serialize a payload carrying only the client token ('{}' when unset)."""
        return json.dumps({_CLIENT_TOKEN_KEY: self._clientToken}) if self._clientToken is not None else '{}'
|
||||
@@ -0,0 +1,63 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except:
|
||||
ssl = None
|
||||
|
||||
|
||||
class SSLContextBuilder(object):
    """Fluent builder for the ``ssl.SSLContext`` used by MQTT connections.

    Construction fails fast when the running platform lacks the SSL/TLS
    capabilities this SDK depends on (SSLContext itself, plus ALPN support).
    Every ``with_*`` method returns the builder so calls can be chained.
    """

    def __init__(self):
        self.check_supportability()
        self._ssl_context = ssl.create_default_context()

    def check_supportability(self):
        """Raise when the interpreter's ssl support is missing or too old."""
        if ssl is None:
            raise RuntimeError("This platform has no SSL/TLS.")
        if not hasattr(ssl, "SSLContext"):
            raise NotImplementedError("This platform does not support SSLContext. Python 2.7.10+/3.5+ is required.")
        if not hasattr(ssl.SSLContext, "set_alpn_protocols"):
            raise NotImplementedError("This platform does not support ALPN as TLS extensions. Python 2.7.10+/3.5+ is required.")

    def with_ca_certs(self, ca_certs):
        """Load the CA certificate(s) used to verify the peer."""
        self._ssl_context.load_verify_locations(ca_certs)
        return self

    def with_cert_key_pair(self, cert_file, key_file):
        """Load the client certificate and its private key."""
        self._ssl_context.load_cert_chain(cert_file, key_file)
        return self

    def with_cert_reqs(self, cert_reqs):
        """Set the peer-certificate verification mode (e.g. ssl.CERT_REQUIRED)."""
        self._ssl_context.verify_mode = cert_reqs
        return self

    def with_check_hostname(self, check_hostname):
        """Enable or disable hostname matching against the peer certificate."""
        self._ssl_context.check_hostname = check_hostname
        return self

    def with_ciphers(self, ciphers):
        """Restrict the cipher suites; None keeps the context's default list.

        set_ciphers() does not accept None, hence the guard.
        """
        if ciphers is not None:
            self._ssl_context.set_ciphers(ciphers)
        return self

    def with_alpn_protocols(self, alpn_protocols):
        """Advertise the given ALPN protocol names during the TLS handshake."""
        self._ssl_context.set_alpn_protocols(alpn_protocols)
        return self

    def build(self):
        """Return the configured ``ssl.SSLContext``."""
        return self._ssl_context
|
||||
@@ -0,0 +1,699 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
# This class implements the progressive backoff logic for auto-reconnect.
|
||||
# It manages the reconnect wait time for the current reconnect, controling
|
||||
# when to increase it and when to reset it.
|
||||
|
||||
|
||||
import re
|
||||
import sys
|
||||
import ssl
|
||||
import errno
|
||||
import struct
|
||||
import socket
|
||||
import base64
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import ClientError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssHandShakeError
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
|
||||
try:
|
||||
from urllib.parse import quote # Python 3+
|
||||
except ImportError:
|
||||
from urllib import quote
|
||||
# INI config file handling
|
||||
try:
|
||||
from configparser import ConfigParser # Python 3+
|
||||
from configparser import NoOptionError
|
||||
from configparser import NoSectionError
|
||||
except ImportError:
|
||||
from ConfigParser import ConfigParser
|
||||
from ConfigParser import NoOptionError
|
||||
from ConfigParser import NoSectionError
|
||||
|
||||
|
||||
class ProgressiveBackOffCore:
    """Progressive (exponential) backoff controller for auto-reconnect.

    Tracks the wait applied before each reconnect attempt: the wait doubles
    on every unstable reconnect (capped at a maximum) and is reset to the
    base value once a connection stays up long enough to be considered
    stable.
    """

    # Logger
    _logger = logging.getLogger(__name__)

    def __init__(self, srcBaseReconnectTimeSecond=1, srcMaximumReconnectTimeSecond=32, srcMinimumConnectTimeSecond=20):
        # Base reconnect wait in seconds (default 1)
        self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
        # Upper bound on the reconnect wait in seconds (default 32)
        self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
        # Seconds a connection must hold before it counts as stable (default 20)
        self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
        # Wait that the next backOff() call will apply
        self._currentBackoffTimeSecond = 1
        # Pending timer that resets the backoff once the connection proves stable
        self._resetBackoffTimer = None

    def configTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond):
        """Reconfigure backoff timing; validates inputs and resets the current wait."""
        if min(srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond) < 0:
            self._logger.error("init: Negative time configuration detected.")
            raise ValueError("Negative time configuration detected.")
        if srcBaseReconnectTimeSecond >= srcMinimumConnectTimeSecond:
            self._logger.error("init: Min connect time should be bigger than base reconnect time.")
            raise ValueError("Min connect time should be bigger than base reconnect time.")
        self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
        self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
        self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
        self._currentBackoffTimeSecond = 1

    def backOff(self):
        """Block for the current backoff period, then grow it for next time.

        Cancels any pending stable-connection timer first. Intended to be
        called on every disconnect/reconnect.
        """
        self._logger.debug("backOff: current backoff time is: " + str(self._currentBackoffTimeSecond) + " sec.")
        self.stopStableConnectionTimer()
        # Block the reconnect logic for the current backoff period
        time.sleep(self._currentBackoffTimeSecond)
        if self._currentBackoffTimeSecond == 0:
            # First attempt: start from the base wait
            self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
        else:
            # r_cur = min(2 * r_cur, r_max)
            self._currentBackoffTimeSecond = min(self._maximumReconnectTimeSecond, 2 * self._currentBackoffTimeSecond)

    def startStableConnectionTimer(self):
        """Arm the timer that resets the backoff once the connection proves stable.

        The timer is cancelled by the next backOff() call if the connection
        drops before the minimum stable period elapses.
        """
        self._resetBackoffTimer = threading.Timer(
            self._minimumConnectTimeSecond, self._connectionStableThenResetBackoffTime)
        self._resetBackoffTimer.start()

    def stopStableConnectionTimer(self):
        """Cancel the pending stable-connection timer, if one is armed."""
        if self._resetBackoffTimer is not None:
            self._resetBackoffTimer.cancel()

    def _connectionStableThenResetBackoffTime(self):
        # Timer callback: the connection held long enough — start over from base
        self._logger.debug(
            "stableConnection: Resetting the backoff time to: " + str(self._baseReconnectTimeSecond) + " sec.")
        self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
|
||||
|
||||
|
||||
class SigV4Core:
    """AWS Signature Version 4 signer for MQTT-over-WebSocket endpoints.

    Resolves IAM credentials from, in order of precedence:
      1. values supplied through setIAMCredentials (custom config),
      2. the AWS_* environment variables,
      3. the AWS CLI credential file (~/.aws/credentials),
    and produces a SigV4 pre-signed wss:// URL for the IoT data endpoint.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._aws_access_key_id = ""
        self._aws_secret_access_key = ""
        self._aws_session_token = ""
        # Same default location the AWS CLI uses
        self._credentialConfigFilePath = "~/.aws/credentials"

    def setIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken):
        """Explicitly set the signing credentials (highest precedence source)."""
        self._aws_access_key_id = srcAWSAccessKeyID
        self._aws_secret_access_key = srcAWSSecretAccessKey
        self._aws_session_token = srcAWSSessionToken

    def _createAmazonDate(self):
        """Return [dateStamp, amzDate], e.g. ['20170101', '20170101T000000Z'].

        Both elements are unicode strings in Py3.x.
        """
        currentTime = datetime.utcnow()
        YMDHMS = currentTime.strftime('%Y%m%dT%H%M%SZ')
        YMD = YMDHMS[0:YMDHMS.index('T')]
        return [YMD, YMDHMS]

    def _sign(self, key, message):
        """HMAC-SHA256 *message* (str) with *key* (bytes); returns raw bytes."""
        return hmac.new(key, message.encode('utf-8'), hashlib.sha256).digest()

    def _getSignatureKey(self, key, dateStamp, regionName, serviceName):
        """Derive the SigV4 signing key through the date/region/service chain."""
        kDate = self._sign(('AWS4' + key).encode('utf-8'), dateStamp)
        kRegion = self._sign(kDate, regionName)
        kService = self._sign(kRegion, serviceName)
        return self._sign(kService, 'aws4_request')

    def _checkIAMCredentials(self):
        """Return a dict of resolved credentials; empty dict when none found."""
        # Check custom config first, then environment variables, then files
        ret = self._checkKeyInCustomConfig()
        if not ret:
            ret = self._checkKeyInEnv()
        if not ret:
            ret = self._checkKeyInFiles()
        # All credentials returned as unicode strings in Py3.x
        return ret

    def _checkKeyInEnv(self):
        """Read credentials from the AWS_* environment variables."""
        ret = dict()
        self._aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
        self._aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
        self._aws_session_token = os.environ.get('AWS_SESSION_TOKEN')
        if self._aws_access_key_id is not None and self._aws_secret_access_key is not None:
            ret["aws_access_key_id"] = self._aws_access_key_id
            ret["aws_secret_access_key"] = self._aws_secret_access_key
            # Session token is optional (present only for temporary credentials)
            if self._aws_session_token is not None:
                ret["aws_session_token"] = self._aws_session_token
            self._logger.debug("IAM credentials from env var.")
        return ret

    def _checkKeyInINIDefault(self, srcConfigParser, sectionName):
        """Read credentials from one section of an INI credential file."""
        ret = dict()
        # Check aws_access_key_id and aws_secret_access_key
        try:
            ret["aws_access_key_id"] = srcConfigParser.get(sectionName, "aws_access_key_id")
            ret["aws_secret_access_key"] = srcConfigParser.get(sectionName, "aws_secret_access_key")
        except NoOptionError:
            # warning() instead of the deprecated warn() alias
            self._logger.warning("Cannot find IAM keyID/secretKey in credential file.")
        # We do not continue searching if we cannot even get IAM id/secret right
        if len(ret) == 2:
            # Check aws_session_token, optional
            try:
                ret["aws_session_token"] = srcConfigParser.get(sectionName, "aws_session_token")
            except NoOptionError:
                self._logger.debug("No AWS Session Token found.")
        return ret

    def _checkKeyInFiles(self):
        """Read credentials from the AWS CLI credential file.

        Tries the 'default' section first, then 'DEFAULT'.
        """
        ret = dict()
        # Should be compatible with aws cli default credential configuration
        # (*NIX/Windows). Expand "~" up front so the path is always bound for
        # the error message below (it was referenced before assignment when
        # an early failure occurred in the original).
        credentialFilePath = os.path.expanduser(self._credentialConfigFilePath)
        try:
            credentialConfig = ConfigParser()
            credentialConfig.read(credentialFilePath)
            # Now we have the file, start looking for credentials...
            ret = self._checkKeyInINIDefault(credentialConfig, "default")
            if not ret:
                ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT")
            self._logger.debug("IAM credentials from file.")
        except IOError:
            self._logger.debug("No IAM credential configuration file in " + credentialFilePath)
        except NoSectionError:
            self._logger.error("Cannot find IAM 'default' section.")
        return ret

    def _checkKeyInCustomConfig(self):
        """Return credentials previously supplied via setIAMCredentials, if any."""
        ret = dict()
        if self._aws_access_key_id != "" and self._aws_secret_access_key != "":
            ret["aws_access_key_id"] = self._aws_access_key_id
            ret["aws_secret_access_key"] = self._aws_secret_access_key
            # Session token is optional
            if self._aws_session_token != "":
                ret["aws_session_token"] = self._aws_session_token
            self._logger.debug("IAM credentials from custom config.")
        return ret

    def createWebsocketEndpoint(self, host, port, region, method, awsServiceName, path):
        """Build a SigV4 pre-signed wss:// URL for the given endpoint.

        Returns the endpoint as a unicode string in Py3.x.
        Raises wssNoKeyInEnvironmentError when no usable credentials exist.
        """
        amazonDate = self._createAmazonDate()
        amazonDateSimple = amazonDate[0]  # Unicode in 3.x
        amazonDateComplex = amazonDate[1]  # Unicode in 3.x
        allKeys = self._checkIAMCredentials()  # Unicode in 3.x
        if not self._hasCredentialsNecessaryForWebsocket(allKeys):
            raise wssNoKeyInEnvironmentError()
        # keyID and secretKey are guaranteed present and non-empty from here
        keyID = allKeys["aws_access_key_id"]
        secretKey = allKeys["aws_secret_access_key"]
        # Query parameters in canonical (alphabetical) order
        queryParameters = "X-Amz-Algorithm=AWS4-HMAC-SHA256" + \
            "&X-Amz-Credential=" + keyID + "%2F" + amazonDateSimple + "%2F" + region + "%2F" + awsServiceName + "%2Faws4_request" + \
            "&X-Amz-Date=" + amazonDateComplex + \
            "&X-Amz-Expires=86400" + \
            "&X-Amz-SignedHeaders=host"  # Unicode in 3.x
        hashedPayload = hashlib.sha256(str("").encode('utf-8')).hexdigest()  # Unicode in 3.x
        # Create the string to sign
        signedHeaders = "host"
        canonicalHeaders = "host:" + host + "\n"
        canonicalRequest = method + "\n" + path + "\n" + queryParameters + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedPayload  # Unicode in 3.x
        hashedCanonicalRequest = hashlib.sha256(str(canonicalRequest).encode('utf-8')).hexdigest()  # Unicode in 3.x
        stringToSign = "AWS4-HMAC-SHA256\n" + amazonDateComplex + "\n" + amazonDateSimple + "/" + region + "/" + awsServiceName + "/aws4_request\n" + hashedCanonicalRequest  # Unicode in 3.x
        # Sign it
        signingKey = self._getSignatureKey(secretKey, amazonDateSimple, region, awsServiceName)
        signature = hmac.new(signingKey, (stringToSign).encode("utf-8"), hashlib.sha256).hexdigest()
        # Assemble the final URL
        url = "wss://" + host + ":" + str(port) + path + '?' + queryParameters + "&X-Amz-Signature=" + signature
        # Append the STS session token when temporary credentials are in use
        awsSessionTokenCandidate = allKeys.get("aws_session_token")
        if awsSessionTokenCandidate is not None and len(awsSessionTokenCandidate) != 0:
            url += "&X-Amz-Security-Token=" + quote(awsSessionTokenCandidate.encode("utf-8"))  # Unicode in 3.x
        self._logger.debug("createWebsocketEndpoint: Websocket URL: " + url)
        return url

    def _hasCredentialsNecessaryForWebsocket(self, allKeys):
        """Return True when both a non-empty key ID and secret key are present.

        The session token is optional and intentionally not checked here.
        """
        awsAccessKeyIdCandidate = allKeys.get("aws_access_key_id")
        awsSecretAccessKeyCandidate = allKeys.get("aws_secret_access_key")
        # Bug fix: the original tested the key ID for None twice and never the
        # secret key, so a dict containing only a key ID crashed on len(None)
        # below instead of returning False.
        validEntries = awsAccessKeyIdCandidate is not None and awsSecretAccessKeyCandidate is not None
        if validEntries:
            # Empty values are NOT considered valid entries
            validEntries &= (len(awsAccessKeyIdCandidate) != 0 and len(awsSecretAccessKeyCandidate) != 0)
        return validEntries
|
||||
|
||||
|
||||
# This is an internal class that buffers the incoming bytes into an
|
||||
# internal buffer until it gets the full desired length of bytes.
|
||||
# At that time, this bufferedReader will be reset.
|
||||
# *Error handling:
|
||||
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
|
||||
# leave them to the paho _packet_read for further handling (ignored and try
|
||||
# again when data is available.
|
||||
# For other errors, leave them to the paho _packet_read for error reporting.
|
||||
|
||||
|
||||
class _BufferedReader:
    """Accumulates incoming bytes from a non-blocking SSL socket until the
    full requested length is available, then resets itself.

    Partial reads survive across calls: if the socket raises a retryable
    error (SSL_ERROR_WANT_READ/WRITE, EAGAIN) mid-request, the next call
    resumes buffering the remainder. All socket errors are propagated for
    paho's _packet_read to handle (retry or report).
    """

    _sslSocket = None
    _internalBuffer = None
    _remainedLength = -1
    _bufferingInProgress = False

    def __init__(self, sslSocket):
        self._sslSocket = sslSocket
        self._internalBuffer = bytearray()
        self._bufferingInProgress = False

    def _reset(self):
        self._internalBuffer = bytearray()
        self._remainedLength = -1
        self._bufferingInProgress = False

    def read(self, numberOfBytesToBeBuffered):
        """Return exactly *numberOfBytesToBeBuffered* bytes as a bytearray."""
        if not self._bufferingInProgress:
            # Fresh request: remember how much we still owe the caller
            self._remainedLength = numberOfBytesToBeBuffered
            self._bufferingInProgress = True

        # Keep asking for whatever is still missing; a temporarily-empty
        # socket raises socket.error, which paho catches and retries.
        while self._remainedLength > 0:
            fragment = self._sslSocket.read(self._remainedLength)
            if not fragment:
                # The server can drop the connection without closing the
                # socket; raise so the reconnect flow kicks in.
                raise socket.error(errno.ECONNABORTED, 0)
            self._internalBuffer.extend(fragment)
            self._remainedLength -= len(fragment)

        # Full length buffered: hand it back and clear the context
        completed = self._internalBuffer
        self._reset()
        return completed  # Always a bytearray
|
||||
|
||||
|
||||
# This is the internal class that sends requested data out chunk by chunk according
|
||||
# to the availablity of the socket write operation. If the requested bytes of data
|
||||
# (after encoding) needs to be sent out in separate socket write operations (most
|
||||
# probably be interrupted by the error socket.error (errno = ssl.SSL_ERROR_WANT_WRITE).)
|
||||
# , the write pointer is stored to ensure that the continued bytes will be sent next
|
||||
# time this function gets called.
|
||||
# *Error handling:
|
||||
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
|
||||
# leave them to the paho _packet_read for further handling (ignored and try
|
||||
# again when data is available.
|
||||
# For other errors, leave them to the paho _packet_read for error reporting.
|
||||
|
||||
|
||||
class _BufferedWriter:
    """Writes an encoded wss frame to a non-blocking SSL socket chunk by chunk.

    When a write is cut short (typically socket.error with
    ssl.SSL_ERROR_WANT_WRITE), the unsent tail is kept so the next call
    resumes where the previous one stopped. Retryable and fatal socket
    errors are propagated for paho's packet loop to handle.
    """

    _sslSocket = None
    _internalBuffer = None
    _writingInProgress = False
    _requestedDataLength = -1

    def __init__(self, sslSocket):
        self._sslSocket = sslSocket
        self._internalBuffer = bytearray()
        self._writingInProgress = False
        self._requestedDataLength = -1

    def _reset(self):
        self._internalBuffer = bytearray()
        self._writingInProgress = False
        self._requestedDataLength = -1

    # Input data for this function needs to be an encoded wss frame
    # Always request for packet[pos=0:] (raw MQTT data)
    def write(self, encodedData, payloadLength):
        """Attempt to send *encodedData* (a bytearray wss frame).

        Returns *payloadLength* once the whole frame is on the wire, or 0
        while transmission of the current frame is still in progress.
        """
        # Begin a new frame only when the previous one finished completely
        if not self._writingInProgress:
            self._internalBuffer = encodedData
            self._writingInProgress = True
            self._requestedDataLength = payloadLength
        # Push out as much of the remaining frame as the socket accepts
        sentBytes = self._sslSocket.write(self._internalBuffer)
        self._internalBuffer = self._internalBuffer[sentBytes:]
        if self._internalBuffer:
            # Partial write: report 0 so paho keeps its packet 'pos' unchanged
            return 0
        # Whole frame sent; report the original MQTT payload length
        completedLength = self._requestedDataLength
        self._reset()
        return completedLength
|
||||
|
||||
|
||||
class SecuredWebSocketCore:
    """WebSocket framing layer for MQTT over an already TLS-secured socket."""
    # Websocket opcodes (values match RFC 6455 frame opcodes)
    _OP_CONTINUATION = 0x0
    _OP_TEXT = 0x1
    _OP_BINARY = 0x2
    _OP_CONNECTION_CLOSE = 0x8
    _OP_PING = 0x9
    _OP_PONG = 0xa
    # Websocket Connect Status
    _WebsocketConnectInit = -1
    _WebsocketDisconnected = 1

    _logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, socket, hostAddress, portNumber, AWSAccessKeyID="", AWSSecretAccessKey="", AWSSessionToken=""):
    """Perform the SigV4-signed WebSocket upgrade handshake over *socket*.

    *socket* is an already-connected, TLS-secured socket. On success the
    instance wraps it in buffered reader/writer helpers for frame I/O.
    Raises ValueError when signing credentials are missing, the endpoint
    is invalid, or the WebSocket handshake is rejected.
    """
    self._connectStatus = self._WebsocketConnectInit
    # Handlers
    self._sslSocket = socket
    self._sigV4Handler = self._createSigV4Core()
    self._sigV4Handler.setIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSessionToken)
    # Endpoint Info
    self._hostAddress = hostAddress
    self._portNumber = portNumber
    # Section flags: which parts of the current incoming frame have been parsed
    self._hasOpByte = False
    self._hasPayloadLengthFirst = False
    self._hasPayloadLengthExtended = False
    self._hasMaskKey = False
    self._hasPayload = False
    # Properties for the current websocket frame
    self._isFIN = False
    self._RSVBits = None
    self._opCode = None
    self._needMaskKey = False
    self._payloadLengthBytesLength = 1
    self._payloadLength = 0
    self._maskKey = None
    self._payloadDataBuffer = bytearray()  # Once the whole wss connection is lost, there is no need to keep the buffered payload
    try:
        self._handShake(hostAddress, portNumber)
    except wssNoKeyInEnvironmentError:  # Handle SigV4 signing and websocket handshaking errors
        raise ValueError("No Access Key/KeyID Error")
    except wssHandShakeError:
        raise ValueError("Websocket Handshake Error")
    except ClientError as e:
        # NOTE(review): relies on ClientError exposing a .message attribute
        # (project-defined exception); plain Py3 exceptions do not have one.
        raise ValueError(e.message)
    # Now we have a socket with secured websocket...
    self._bufferedReader = _BufferedReader(self._sslSocket)
    self._bufferedWriter = _BufferedWriter(self._sslSocket)
|
||||
|
||||
def _createSigV4Core(self):
    # Factory for the SigV4 signer used during the handshake
    return SigV4Core()
|
||||
|
||||
def _generateMaskKey(self):
    """Return 4 random mask bytes for an outgoing client frame."""
    return bytearray(os.urandom(4))
    # os.urandom returns ascii str in 2.x, converted to bytearray
    # os.urandom returns bytes in 3.x, converted to bytearray
|
||||
|
||||
def _reset(self):  # Reset the context for wss frame reception
    """Clear per-frame parsing state; buffered payload data is preserved."""
    # Control info
    self._hasOpByte = False
    self._hasPayloadLengthFirst = False
    self._hasPayloadLengthExtended = False
    self._hasMaskKey = False
    self._hasPayload = False
    # Frame Info
    self._isFIN = False
    self._RSVBits = None
    self._opCode = None
    self._needMaskKey = False
    self._payloadLengthBytesLength = 1
    self._payloadLength = 0
    self._maskKey = None
    # Never reset the payloadData since we might have fragmented MQTT data from the previous frame
|
||||
|
||||
def _generateWSSKey(self):
    """Generate a random Sec-WebSocket-Key value (base64-encoded bytes)."""
    return base64.b64encode(os.urandom(128))  # Bytes
|
||||
|
||||
def _verifyWSSResponse(self, response, clientKey):
    """Validate the server's handshake response: 101 status, upgrade headers,
    and a matching Sec-WebSocket-Accept key."""
    # Check if it is a 101 response
    rawResponse = response.strip().lower()
    if b"101 switching protocols" not in rawResponse or b"upgrade: websocket" not in rawResponse or b"connection: upgrade" not in rawResponse:
        return False
    # Parse out the sec-websocket-accept
    # NOTE(review): this lookup is case-sensitive (searches the original,
    # un-lowered response) while the checks above are lower-cased — assumes
    # the server emits the header name in lowercase; confirm against servers.
    WSSAcceptKeyIndex = response.strip().index(b"sec-websocket-accept: ") + len(b"sec-websocket-accept: ")
    rawSecWebSocketAccept = response.strip()[WSSAcceptKeyIndex:].split(b"\r\n")[0].strip()
    # Verify the WSSAcceptKey
    return self._verifyWSSAcceptKey(rawSecWebSocketAccept, clientKey)
|
||||
|
||||
def _verifyWSSAcceptKey(self, srcAcceptKey, clientKey):
    """Check srcAcceptKey == base64(SHA1(clientKey + GUID)).

    The fixed GUID is the WebSocket handshake constant; both inputs and the
    comparison are byte strings.
    """
    GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    verifyServerAcceptKey = base64.b64encode((hashlib.sha1(clientKey + GUID)).digest())  # Bytes
    return srcAcceptKey == verifyServerAcceptKey
|
||||
|
||||
def _handShake(self, hostAddress, portNumber):
    """Send the HTTP Upgrade request with a SigV4-signed URL; verify the response.

    Raises ClientError for a malformed endpoint, socket.error when the
    response does not arrive within the timeout, and wssHandShakeError when
    the server response fails validation. Signing errors propagate from
    createWebsocketEndpoint.
    """
    CRLF = "\r\n"
    # Only AWS IoT data endpoints are accepted; group(2) captures the region
    IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+(\.ats|-ats)?\.iot\.(.*)\.amazonaws\..*"
    matched = re.compile(IOT_ENDPOINT_PATTERN, re.IGNORECASE).match(hostAddress)
    if not matched:
        raise ClientError("Invalid endpoint pattern for wss: %s" % hostAddress)
    region = matched.group(2)
    signedURL = self._sigV4Handler.createWebsocketEndpoint(hostAddress, portNumber, region, "GET", "iotdata", "/mqtt")
    # Now we got a signedURL; keep only path + query for the request line
    path = signedURL[signedURL.index("/mqtt"):]
    # Assemble HTTP request headers
    Method = "GET " + path + " HTTP/1.1" + CRLF
    Host = "Host: " + hostAddress + CRLF
    Connection = "Connection: " + "Upgrade" + CRLF
    Upgrade = "Upgrade: " + "websocket" + CRLF
    secWebSocketVersion = "Sec-WebSocket-Version: " + "13" + CRLF
    rawSecWebSocketKey = self._generateWSSKey()  # Bytes
    secWebSocketKey = "sec-websocket-key: " + rawSecWebSocketKey.decode('utf-8') + CRLF  # Should be randomly generated...
    secWebSocketProtocol = "Sec-WebSocket-Protocol: " + "mqttv3.1" + CRLF
    secWebSocketExtensions = "Sec-WebSocket-Extensions: " + "permessage-deflate; client_max_window_bits" + CRLF
    # Send the HTTP request
    # Ensure that we are sending bytes, not by any chance unicode string
    handshakeBytes = Method + Host + Connection + Upgrade + secWebSocketVersion + secWebSocketProtocol + secWebSocketExtensions + secWebSocketKey + CRLF
    handshakeBytes = handshakeBytes.encode('utf-8')
    self._sslSocket.write(handshakeBytes)
    # Read it back (non-blocking socket): loop until data arrives or timeout
    timeStart = time.time()
    wssHandshakeResponse = bytearray()
    while len(wssHandshakeResponse) == 0:
        try:
            wssHandshakeResponse += self._sslSocket.read(1024)  # Response is always less than 1024 bytes
        except socket.error as err:
            if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
                # Retryable: keep polling until the overall timeout elapses
                if time.time() - timeStart > self._getTimeoutSec():
                    raise err  # We make sure that reconnect gets retried in Paho upon a wss reconnect response timeout
            else:
                raise err
    # Verify response
    # Now both wssHandshakeResponse and rawSecWebSocketKey are byte strings
    if not self._verifyWSSResponse(wssHandshakeResponse, rawSecWebSocketKey):
        raise wssHandShakeError()
    else:
        pass
|
||||
|
||||
def _getTimeoutSec(self):
    """Timeout (seconds) applied to the wss handshake read loop."""
    return DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
|
||||
|
||||
# Used to create a single wss frame
|
||||
# Assume that the maximum length of a MQTT packet never exceeds the maximum length
|
||||
# for a wss frame. Therefore, the FIN bit for the encoded frame will always be 1.
|
||||
# Frames are encoded as BINARY frames.
|
||||
def _encodeFrame(self, rawPayload, opCode, masked=1):
|
||||
ret = bytearray()
|
||||
# Op byte
|
||||
opByte = 0x80 | opCode # Always a FIN, no RSV bits
|
||||
ret.append(opByte)
|
||||
# Payload Length bytes
|
||||
maskBit = masked
|
||||
payloadLength = len(rawPayload)
|
||||
if payloadLength <= 125:
|
||||
ret.append((maskBit << 7) | payloadLength)
|
||||
elif payloadLength <= 0xffff: # 16-bit unsigned int
|
||||
ret.append((maskBit << 7) | 126)
|
||||
ret.extend(struct.pack("!H", payloadLength))
|
||||
elif payloadLength <= 0x7fffffffffffffff: # 64-bit unsigned int (most significant bit must be 0)
|
||||
ret.append((maskBit << 7) | 127)
|
||||
ret.extend(struct.pack("!Q", payloadLength))
|
||||
else: # Overflow
|
||||
raise ValueError("Exceeds the maximum number of bytes for a single websocket frame.")
|
||||
if maskBit == 1:
|
||||
# Mask key bytes
|
||||
maskKey = self._generateMaskKey()
|
||||
ret.extend(maskKey)
|
||||
# Mask the payload
|
||||
payloadBytes = bytearray(rawPayload)
|
||||
if maskBit == 1:
|
||||
for i in range(0, payloadLength):
|
||||
payloadBytes[i] ^= maskKey[i % 4]
|
||||
ret.extend(payloadBytes)
|
||||
# Return the assembled wss frame
|
||||
return ret
|
||||
|
||||
# Used by the wss client to terminate a wss connection:
# build and send a masked CLOSE control frame.
def _closeWssConnection(self):
    """Send a websocket closing frame (client-to-server frames are masked)."""
    closingFrame = self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, masked=1)
    self._sslSocket.write(closingFrame)
|
||||
|
||||
# Used by the wss client to answer a wss PING from the server:
# build and send a masked PONG control frame.
def _sendPONG(self):
    """Send a websocket PONG frame (client-to-server frames are masked)."""
    pongFrame = self._encodeFrame(b"", self._OP_PONG, masked=1)
    self._sslSocket.write(pongFrame)
|
||||
|
||||
# Override sslSocket read. Always read from the wss internal payload buffer, which
# contains the unmasked MQTT payload. This read decodes ONE wss frame per call
# and loads its payload for MQTT _packet_read. At any time, MQTT _packet_read
# should be able to read a complete MQTT packet from the payload (buffered per wss
# frame payload). If the MQTT packet is split across separate wss frames, the
# chunks are buffered frame by frame and MQTT _packet_read will not be able
# to collect a complete MQTT packet to operate on until the necessary payload is
# fully buffered.
# If the requested number of bytes is not available, SSL_ERROR_WANT_READ is
# raised to trigger another call of _packet_read when the data is available again.
def read(self, numberOfBytes):
    """Return exactly *numberOfBytes* bytes of buffered MQTT payload.

    Decodes at most one websocket frame per call. Raises
    socket.error(SSL_ERROR_WANT_READ) whenever the requested amount is not
    yet available so that Paho retries the packet read later.
    """
    # Fast path: the internal buffer already holds enough decoded payload.
    # _payloadDataBuffer is only non-empty once a wss frame payload has been buffered.
    if len(self._payloadDataBuffer) >= numberOfBytes:
        ret = self._payloadDataBuffer[0:numberOfBytes]
        self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
        # struct.unpack(fmt, string) # Py2.x
        # struct.unpack(fmt, buffer) # Py3.x
        # Here ret is always in bytes (buffer interface)
        if sys.version_info[0] < 3:  # Py2.x callers expect str
            ret = str(ret)
        return ret
    # Not enough buffered data: decode the next wss frame from the socket.
    # Each stage below is guarded by a _has* flag so decoding can resume at
    # the right stage after an SSL_ERROR_WANT_READ from the non-blocking socket.
    if not self._hasOpByte:  # Stage 1: op byte (FIN / RSV bits / opcode)
        opByte = self._bufferedReader.read(1)
        self._isFIN = (opByte[0] & 0x80) == 0x80
        self._RSVBits = (opByte[0] & 0x70)
        self._opCode = (opByte[0] & 0x0f)
        self._hasOpByte = True  # Finished buffering opByte
        # Any RSV bit set is a protocol violation, since this client never
        # negotiates extensions: close the connection.
        if self._RSVBits != 0x0:
            self._closeWssConnection()
            self._connectStatus = self._WebsocketDisconnected
            self._payloadDataBuffer = bytearray()
            raise socket.error(ssl.SSL_ERROR_WANT_READ, "RSV bits set with NO negotiated extensions.")
    if not self._hasPayloadLengthFirst:  # Stage 2: mask bit + 7-bit length
        payloadLengthFirst = self._bufferedReader.read(1)
        self._hasPayloadLengthFirst = True  # Finished buffering first byte of payload length
        self._needMaskKey = (payloadLengthFirst[0] & 0x80) == 0x80
        payloadLengthFirstByteArray = bytearray()
        payloadLengthFirstByteArray.extend(payloadLengthFirst)
        self._payloadLength = (payloadLengthFirstByteArray[0] & 0x7f)

        if self._payloadLength == 126:  # 16-bit extended length follows
            self._payloadLengthBytesLength = 2
            self._hasPayloadLengthExtended = False  # Force to buffer the extended length
        elif self._payloadLength == 127:  # 64-bit extended length follows
            self._payloadLengthBytesLength = 8
            self._hasPayloadLengthExtended = False  # Force to buffer the extended length
        else:  # _payloadLength <= 125: length is already complete
            self._hasPayloadLengthExtended = True  # No need to buffer extended payload length
    if not self._hasPayloadLengthExtended:  # Stage 3: extended payload length bytes
        payloadLengthExtended = self._bufferedReader.read(self._payloadLengthBytesLength)
        self._hasPayloadLengthExtended = True
        if sys.version_info[0] < 3:
            payloadLengthExtended = str(payloadLengthExtended)
        if self._payloadLengthBytesLength == 2:
            self._payloadLength = struct.unpack("!H", payloadLengthExtended)[0]
        else:  # _payloadLengthBytesLength == 8
            self._payloadLength = struct.unpack("!Q", payloadLengthExtended)[0]

    # A server-to-client frame must never be masked; if it is, close the
    # connection and let Paho retry.
    if self._needMaskKey:
        self._closeWssConnection()
        self._connectStatus = self._WebsocketDisconnected
        self._payloadDataBuffer = bytearray()
        raise socket.error(ssl.SSL_ERROR_WANT_READ, "Server response masked, closing connection and try again.")

    if not self._hasPayload:  # Stage 4: the frame payload itself
        payloadForThisFrame = self._bufferedReader.read(self._payloadLength)
        self._hasPayload = True
        # Server frames are guaranteed unmasked at this point (see the check
        # above), so the payload is appended to the internal buffer as-is.
        self._payloadDataBuffer.extend(payloadForThisFrame)
    # The complete wss frame has been consumed: handle control frames and
    # reset the per-frame decoding context.
    if self._opCode == self._OP_CONNECTION_CLOSE:
        self._connectStatus = self._WebsocketDisconnected
        # Once the wss closing frame arrives there is nothing more to read;
        # start all over again.
        self._payloadDataBuffer = bytearray()
    if self._opCode == self._OP_PING:
        # Answer the PING; if the transmission of the last wss MQTT packet is
        # not finished, it will continue afterwards.
        self._sendPONG()
    self._reset()
    # Check again whether the buffer now satisfies Paho's request.
    if len(self._payloadDataBuffer) >= numberOfBytes:
        ret = self._payloadDataBuffer[0:numberOfBytes]
        self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
        # Here ret is always in bytes (buffer interface)
        if sys.version_info[0] < 3:  # Py2.x
            ret = str(ret)
        return ret
    else:  # Fragmented MQTT packet spread over separate wss frames
        raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not a complete MQTT packet payload within this wss frame.")
|
||||
|
||||
def write(self, bytesToBeSent):
    """Wrap *bytesToBeSent* in a masked BINARY wss frame and send it.

    This low-level write always goes to the plain buffered writer; error
    reporting is left to the Python socket layer and wss closing-frame
    handling happens in read().
    """
    # When there is a disconnection, select reports a TypeError which triggers
    # the reconnect. During reconnect Paho sets the (wss-mocked) socket object
    # to None, blocking other operations until a connection is re-established.
    wssFrame = self._encodeFrame(bytesToBeSent, self._OP_BINARY, 1)
    return self._bufferedWriter.write(wssFrame, len(bytesToBeSent))
|
||||
|
||||
def close(self):
    """Close and drop the underlying SSL socket, if one is still held."""
    sslSocket = self._sslSocket
    if sslSocket is None:
        return
    sslSocket.close()
    self._sslSocket = None
|
||||
|
||||
def getpeercert(self):
    """Delegate to the underlying SSL socket's peer-certificate accessor."""
    return self._sslSocket.getpeercert()
|
||||
|
||||
def getSSLSocket(self):
    """Return the underlying SSL socket while the wss session is alive.

    Once the websocket is disconnected this returns None, leaving the
    actual socket close to Paho (_ssl.close() -> wssCore.close()).
    """
    if self._connectStatus == self._WebsocketDisconnected:
        return None
    return self._sslSocket
|
||||
@@ -0,0 +1,244 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import ssl
|
||||
import logging
|
||||
from threading import Lock
|
||||
from numbers import Number
|
||||
import AWSIoTPythonSDK.core.protocol.paho.client as mqtt
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
|
||||
|
||||
|
||||
class ClientStatus(object):
    """Integer constants describing the client's life-cycle state."""

    IDLE = 0
    CONNECT = 1
    RESUBSCRIBE = 2
    DRAINING = 3
    STABLE = 4
    USER_DISCONNECT = 5
    ABNORMAL_DISCONNECT = 6
|
||||
|
||||
|
||||
class ClientStatusContainer(object):
    """Holder for the current ClientStatus value.

    Once the user has requested a disconnect (USER_DISCONNECT), the only
    accepted transition is back to CONNECT; any other update is ignored.
    """

    def __init__(self):
        self._status = ClientStatus.IDLE

    def get_status(self):
        """Return the currently stored status."""
        return self._status

    def set_status(self, status):
        """Update the status, honoring the USER_DISCONNECT latch."""
        locked_by_user_disconnect = (ClientStatus.USER_DISCONNECT == self._status)
        if not locked_by_user_disconnect or ClientStatus.CONNECT == status:
            self._status = status
|
||||
|
||||
|
||||
class InternalAsyncMqttClient(object):
    """Thin asynchronous wrapper around the bundled Paho MQTT client.

    Maintains a map from message id (or fixed sentinel mid) to the user
    acknowledgement callback that should fire when the corresponding
    event is dispatched. Access to the map is guarded by a Lock.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, client_id, clean_session, protocol, use_wss):
        # user_data is always None; callbacks are routed via the event map.
        self._paho_client = self._create_paho_client(client_id, clean_session, None, protocol, use_wss)
        self._use_wss = use_wss
        self._event_callback_map_lock = Lock()
        self._event_callback_map = dict()

    def _create_paho_client(self, client_id, clean_session, user_data, protocol, use_wss):
        """Instantiate the underlying Paho client."""
        self._logger.debug("Initializing MQTT layer...")
        return mqtt.Client(client_id, clean_session, user_data, protocol, use_wss)

    # TODO: Merge credentials providers configuration into one
    def set_cert_credentials_provider(self, cert_credentials_provider):
        """Configure TLS from the given certificate credentials provider.

        Over wss only the CA is needed (auth happens via SigV4); for mutual
        auth the client cert and private key are configured as well.
        """
        # History issue from Yun SDK where AR9331 embedded Linux only have Python 2.7.3
        # pre-installed. In this version, TLSv1_2 is not even an option.
        # SSLv23 is a work-around which selects the highest TLS version between the client
        # and service. If user installs opensslv1.0.1+, this option will work fine for Mutual
        # Auth.
        # Note that we cannot force TLSv1.2 for Mutual Auth. in Python 2.7.3 and TLS support
        # in Python only starts from Python2.7.
        # See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23
        if self._use_wss:
            ca_path = cert_credentials_provider.get_ca_path()
            self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)
        else:
            ca_path = cert_credentials_provider.get_ca_path()
            cert_path = cert_credentials_provider.get_cert_path()
            key_path = cert_credentials_provider.get_key_path()
            self._paho_client.tls_set(ca_certs=ca_path,certfile=cert_path, keyfile=key_path,
                                      cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)

    def set_iam_credentials_provider(self, iam_credentials_provider):
        """Pass IAM credentials (for SigV4/websocket auth) down to Paho."""
        self._paho_client.configIAMCredentials(iam_credentials_provider.get_access_key_id(),
                                               iam_credentials_provider.get_secret_access_key(),
                                               iam_credentials_provider.get_session_token())

    def set_endpoint_provider(self, endpoint_provider):
        """Store the provider used by connect() to resolve host and port."""
        self._endpoint_provider = endpoint_provider

    def configure_last_will(self, topic, payload, qos, retain=False):
        """Register an MQTT last-will message with the underlying client."""
        self._paho_client.will_set(topic, payload, qos, retain)

    def configure_alpn_protocols(self, alpn_protocols):
        """Configure the ALPN protocol list used during the TLS handshake."""
        self._paho_client.config_alpn_protocols(alpn_protocols)

    def clear_last_will(self):
        """Remove any previously configured last-will message."""
        self._paho_client.will_clear()

    def set_username_password(self, username, password=None):
        """Set the MQTT username (and optional password)."""
        self._paho_client.username_pw_set(username, password)

    def set_socket_factory(self, socket_factory):
        """Install a custom socket factory on the underlying client."""
        self._paho_client.socket_factory_set(socket_factory)

    def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
        """Tune Paho's progressive reconnect back-off timing."""
        self._paho_client.setBackoffTiming(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)

    def connect(self, keep_alive_sec, ack_callback=None):
        """Start an async connect; returns the Paho rc.

        Fixed event callbacks are registered BEFORE the connect attempt so
        the CONNACK cannot arrive with an empty callback map. On success the
        background network I/O thread is started.
        """
        host = self._endpoint_provider.get_host()
        port = self._endpoint_provider.get_port()

        with self._event_callback_map_lock:
            self._logger.debug("Filling in fixed event callbacks: CONNACK, DISCONNECT, MESSAGE")
            self._event_callback_map[FixedEventMids.CONNACK_MID] = self._create_combined_on_connect_callback(ack_callback)
            self._event_callback_map[FixedEventMids.DISCONNECT_MID] = self._create_combined_on_disconnect_callback(None)
            self._event_callback_map[FixedEventMids.MESSAGE_MID] = self._create_converted_on_message_callback()

        rc = self._paho_client.connect(host, port, keep_alive_sec)
        if MQTT_ERR_SUCCESS == rc:
            self.start_background_network_io()

        return rc

    def start_background_network_io(self):
        """Start Paho's network loop thread."""
        self._logger.debug("Starting network I/O thread...")
        self._paho_client.loop_start()

    def stop_background_network_io(self):
        """Stop Paho's network loop thread."""
        self._logger.debug("Stopping network I/O thread...")
        self._paho_client.loop_stop()

    def disconnect(self, ack_callback=None):
        """Start an async disconnect; returns the Paho rc.

        On success the DISCONNECT sentinel callback is replaced with one
        that also invokes *ack_callback*.
        """
        with self._event_callback_map_lock:
            rc = self._paho_client.disconnect()
            if MQTT_ERR_SUCCESS == rc:
                self._logger.debug("Filling in custom disconnect event callback...")
                combined_on_disconnect_callback = self._create_combined_on_disconnect_callback(ack_callback)
                self._event_callback_map[FixedEventMids.DISCONNECT_MID] = combined_on_disconnect_callback
            return rc

    def _create_combined_on_connect_callback(self, ack_callback):
        # Wrap the user ack so on_online() always fires first.
        def combined_on_connect_callback(mid, data):
            self.on_online()
            if ack_callback:
                ack_callback(mid, data)
        return combined_on_connect_callback

    def _create_combined_on_disconnect_callback(self, ack_callback):
        # Wrap the user ack so on_offline() always fires first.
        def combined_on_disconnect_callback(mid, data):
            self.on_offline()
            if ack_callback:
                ack_callback(mid, data)
        return combined_on_disconnect_callback

    def _create_converted_on_message_callback(self):
        # Adapt the (mid, data) event signature to on_message(message).
        def converted_on_message_callback(mid, data):
            self.on_message(data)
        return converted_on_message_callback

    # For client online notification; subclasses/users override.
    def on_online(self):
        pass

    # For client offline notification; subclasses/users override.
    def on_offline(self):
        pass

    # For client message reception notification; subclasses/users override.
    def on_message(self, message):
        pass

    def publish(self, topic, payload, qos, retain=False, ack_callback=None):
        """Publish and, for QoS > 0, register *ack_callback* under the new mid."""
        with self._event_callback_map_lock:
            rc, mid = self._paho_client.publish(topic, payload, qos, retain)
            if MQTT_ERR_SUCCESS == rc and qos > 0 and ack_callback:
                self._logger.debug("Filling in custom puback (QoS>0) event callback...")
                self._event_callback_map[mid] = ack_callback
            return rc, mid

    def subscribe(self, topic, qos, ack_callback=None):
        """Subscribe and register *ack_callback* under the new mid."""
        with self._event_callback_map_lock:
            rc, mid = self._paho_client.subscribe(topic, qos)
            if MQTT_ERR_SUCCESS == rc and ack_callback:
                self._logger.debug("Filling in custom suback event callback...")
                self._event_callback_map[mid] = ack_callback
            return rc, mid

    def unsubscribe(self, topic, ack_callback=None):
        """Unsubscribe and register *ack_callback* under the new mid."""
        with self._event_callback_map_lock:
            rc, mid = self._paho_client.unsubscribe(topic)
            if MQTT_ERR_SUCCESS == rc and ack_callback:
                self._logger.debug("Filling in custom unsuback event callback...")
                self._event_callback_map[mid] = ack_callback
            return rc, mid

    def register_internal_event_callbacks(self, on_connect, on_disconnect, on_publish, on_subscribe, on_unsubscribe, on_message):
        """Attach the SDK's internal callbacks to the Paho client."""
        self._logger.debug("Registering internal event callbacks to MQTT layer...")
        self._paho_client.on_connect = on_connect
        self._paho_client.on_disconnect = on_disconnect
        self._paho_client.on_publish = on_publish
        self._paho_client.on_subscribe = on_subscribe
        self._paho_client.on_unsubscribe = on_unsubscribe
        self._paho_client.on_message = on_message

    def unregister_internal_event_callbacks(self):
        """Detach all internal callbacks from the Paho client."""
        self._logger.debug("Unregistering internal event callbacks from MQTT layer...")
        self._paho_client.on_connect = None
        self._paho_client.on_disconnect = None
        self._paho_client.on_publish = None
        self._paho_client.on_subscribe = None
        self._paho_client.on_unsubscribe = None
        self._paho_client.on_message = None

    def invoke_event_callback(self, mid, data=None):
        """Invoke the callback registered for *mid*, if any.

        Numeric mids belong to pub/sub/unsub acks and are one-shot: their
        callbacks are removed after invocation. Fixed string mids
        (CONNACK/DISCONNECT/MESSAGE) are kept.
        """
        with self._event_callback_map_lock:
            event_callback = self._event_callback_map.get(mid)
        # For invoking the event callback, we do not need to acquire the lock
        if event_callback:
            self._logger.debug("Invoking custom event callback...")
            if data is not None:
                event_callback(mid=mid, data=data)
            else:
                event_callback(mid=mid)
            # NOTE(review): the lock is dropped between the lookup above and
            # the delete below; a concurrent remove_event_callback for the
            # same mid would make this `del` raise KeyError — confirm whether
            # callers can race here.
            if isinstance(mid, Number):  # Do NOT remove callbacks for CONNACK/DISCONNECT/MESSAGE
                self._logger.debug("This custom event callback is for pub/sub/unsub, removing it after invocation...")
                with self._event_callback_map_lock:
                    del self._event_callback_map[mid]

    def remove_event_callback(self, mid):
        """Remove the callback registered for *mid*, if present."""
        with self._event_callback_map_lock:
            if mid in self._event_callback_map:
                self._logger.debug("Removing custom event callback...")
                del self._event_callback_map[mid]

    def clean_up_event_callbacks(self):
        """Drop every registered event callback."""
        with self._event_callback_map_lock:
            self._event_callback_map.clear()

    def get_event_callback_map(self):
        # Exposes the live map; callers must not mutate it without the lock.
        return self._event_callback_map
|
||||
@@ -0,0 +1,20 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
# Timeout (seconds) for connect/disconnect operations; also used as the
# websocket handshake read timeout.
DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC = 30
# Timeout (seconds) for publish/subscribe/unsubscribe operations —
# presumably the synchronous-API wait; confirm against the callers.
DEFAULT_OPERATION_TIMEOUT_SEC = 5
# Pause (seconds) between replayed requests when draining the offline queue.
DEFAULT_DRAINING_INTERNAL_SEC = 0.5
# Query-string prefix carrying SDK usage metrics — presumably appended to
# the MQTT username; confirm against the caller.
METRICS_PREFIX = "?SDK=Python&Version="
# ALPN protocol name for AWS IoT MQTT.
# NOTE(review): identifier is misspelled ("PROTCOLS") but it is part of the
# public constant surface; renaming would break importers.
ALPN_PROTCOLS = "x-amzn-mqtt-ca"
|
||||
@@ -0,0 +1,29 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
class EventTypes(object):
    """Integer tags for the event tuples exchanged between producer and consumer."""
    CONNACK = 0
    DISCONNECT = 1
    PUBACK = 2
    SUBACK = 3
    UNSUBACK = 4
    MESSAGE = 5
|
||||
|
||||
|
||||
class FixedEventMids(object):
    """Sentinel message ids for events that carry no numeric Paho mid."""
    CONNACK_MID = "CONNECTED"
    DISCONNECT_MID = "DISCONNECTED"
    MESSAGE_MID = "MESSAGE"
    QUEUED_MID = "QUEUED"
|
||||
@@ -0,0 +1,87 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import logging
|
||||
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
|
||||
|
||||
|
||||
class AppendResults(object):
    """Return codes for OfflineRequestQueue.append."""
    APPEND_FAILURE_QUEUE_FULL = -1
    APPEND_FAILURE_QUEUE_DISABLED = -2
    APPEND_SUCCESS = 0
|
||||
|
||||
|
||||
class OfflineRequestQueue(list):
    """FIFO queue for requests made while the client is offline.

    max_size > 0  -> bounded queue; max_size == 0 -> queue disabled;
    max_size < 0  -> unbounded queue. When a bounded queue is full, the
    configured drop behavior decides whether the incoming (newest) or the
    oldest element is discarded.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, max_size, drop_behavior=DropBehaviorTypes.DROP_NEWEST):
        """Create the queue.

        Raises TypeError if max_size/drop_behavior are not integers and
        ValueError if drop_behavior is not a supported DropBehaviorTypes value.
        """
        if not isinstance(max_size, int) or not isinstance(drop_behavior, int):
            self._logger.error("init: MaximumSize/DropBehavior must be integer.")
            raise TypeError("MaximumSize/DropBehavior must be integer.")
        if drop_behavior != DropBehaviorTypes.DROP_OLDEST and drop_behavior != DropBehaviorTypes.DROP_NEWEST:
            self._logger.error("init: Drop behavior not supported.")
            raise ValueError("Drop behavior not supported.")

        # Initialize this list instance itself. (Bug fix: the original code
        # called list.__init__ on a temporary literal, which was a no-op.)
        super(OfflineRequestQueue, self).__init__()
        self._drop_behavior = drop_behavior
        # When self._max_size > 0, queue is limited
        # When self._max_size == 0, queue is disabled
        # When self._max_size < 0, queue is infinite
        self._max_size = max_size

    def _is_enabled(self):
        """Return True unless the queue is disabled (max_size == 0)."""
        return self._max_size != 0

    def _need_drop_messages(self):
        """Return True when appending would require dropping a message.

        Need to drop messages when:
        1. Queue is limited and full
        2. Queue is disabled
        """
        is_queue_full = len(self) >= self._max_size
        is_queue_limited = self._max_size > 0
        is_queue_disabled = not self._is_enabled()
        return (is_queue_full and is_queue_limited) or is_queue_disabled

    def set_behavior_drop_newest(self):
        """When full, reject the incoming element."""
        self._drop_behavior = DropBehaviorTypes.DROP_NEWEST

    def set_behavior_drop_oldest(self):
        """When full, evict the oldest queued element."""
        self._drop_behavior = DropBehaviorTypes.DROP_OLDEST

    # Override
    def append(self, data):
        """Append *data* to the queue, honoring size limit and drop behavior.

        Returns APPEND_SUCCESS on success,
        APPEND_FAILURE_QUEUE_FULL when the bounded queue was full (one
        element was dropped according to the drop behavior), or
        APPEND_FAILURE_QUEUE_DISABLED when the queue is disabled.
        """
        ret = AppendResults.APPEND_SUCCESS
        if self._is_enabled():
            if self._need_drop_messages():
                # We should drop the newest
                if DropBehaviorTypes.DROP_NEWEST == self._drop_behavior:
                    # Logger.warn is deprecated; use warning().
                    self._logger.warning("append: Full queue. Drop the newest: " + str(data))
                    ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
                # We should drop the oldest
                else:
                    current_oldest = super(OfflineRequestQueue, self).pop(0)
                    self._logger.warning("append: Full queue. Drop the oldest: " + str(current_oldest))
                    super(OfflineRequestQueue, self).append(data)
                    ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
            else:
                self._logger.debug("append: Add new element: " + str(data))
                super(OfflineRequestQueue, self).append(data)
        else:
            self._logger.debug("append: Queue is disabled. Drop the message: " + str(data))
            ret = AppendResults.APPEND_FAILURE_QUEUE_DISABLED
        return ret
|
||||
@@ -0,0 +1,27 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
class RequestTypes(object):
    """Kinds of client requests that can be queued for offline replay."""
    CONNECT = 0
    DISCONNECT = 1
    PUBLISH = 2
    SUBSCRIBE = 3
    UNSUBSCRIBE = 4
|
||||
|
||||
class QueueableRequest(object):
    """A captured request: its RequestTypes tag plus the call payload."""

    def __init__(self, type, data):
        self.type = type
        self.data = data  # Can be a tuple of the original call arguments
|
||||
@@ -0,0 +1,296 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import time
|
||||
import logging
|
||||
from threading import Thread
|
||||
from threading import Event
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
|
||||
from AWSIoTPythonSDK.core.protocol.internal.queues import OfflineRequestQueue
|
||||
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import topic_matches_sub
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC
|
||||
|
||||
|
||||
class EventProducer(object):
    """Translates Paho callbacks into (mid, event_type, data) event tuples.

    Each callback enqueues exactly one tuple on the shared event queue and
    notifies the consumer through the shared condition variable.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, cv, event_queue):
        self._cv = cv
        self._event_queue = event_queue

    def _add_to_queue(self, mid, event_type, data):
        """Enqueue one event tuple and wake the consumer."""
        with self._cv:
            self._event_queue.put((mid, event_type, data))
            self._cv.notify()

    def on_connect(self, client, user_data, flags, rc):
        self._add_to_queue(FixedEventMids.CONNACK_MID, EventTypes.CONNACK, rc)
        self._logger.debug("Produced [connack] event")

    def on_disconnect(self, client, user_data, rc):
        self._add_to_queue(FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, rc)
        self._logger.debug("Produced [disconnect] event")

    def on_publish(self, client, user_data, mid):
        self._add_to_queue(mid, EventTypes.PUBACK, None)
        self._logger.debug("Produced [puback] event")

    def on_subscribe(self, client, user_data, mid, granted_qos):
        self._add_to_queue(mid, EventTypes.SUBACK, granted_qos)
        self._logger.debug("Produced [suback] event")

    def on_unsubscribe(self, client, user_data, mid):
        self._add_to_queue(mid, EventTypes.UNSUBACK, None)
        self._logger.debug("Produced [unsuback] event")

    def on_message(self, client, user_data, message):
        self._add_to_queue(FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, message)
        self._logger.debug("Produced [message] event")
|
||||
|
||||
|
||||
class EventConsumer(object):
|
||||
|
||||
MAX_DISPATCH_INTERNAL_SEC = 0.01
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, cv, event_queue, internal_async_client,
             subscription_manager, offline_requests_manager, client_status):
    """Wire the consumer to the shared event queue and its collaborators."""
    self._cv = cv
    self._event_queue = event_queue
    self._internal_async_client = internal_async_client
    self._subscription_manager = subscription_manager
    self._offline_requests_manager = offline_requests_manager
    self._client_status = client_status
    self._is_running = False
    self._draining_interval_sec = DEFAULT_DRAINING_INTERNAL_SEC
    # Event-type -> dispatch handler table
    self._dispatch_methods = {
        EventTypes.CONNACK: self._dispatch_connack,
        EventTypes.DISCONNECT: self._dispatch_disconnect,
        EventTypes.PUBACK: self._dispatch_puback,
        EventTypes.SUBACK: self._dispatch_suback,
        EventTypes.UNSUBACK: self._dispatch_unsuback,
        EventTypes.MESSAGE: self._dispatch_message,
    }
    # Offline request-type -> replay handler table
    self._offline_request_handlers = {
        RequestTypes.PUBLISH: self._handle_offline_publish,
        RequestTypes.SUBSCRIBE: self._handle_offline_subscribe,
        RequestTypes.UNSUBSCRIBE: self._handle_offline_unsubscribe,
    }
    # Set once the dispatching thread has fully exited
    self._stopper = Event()
|
||||
|
||||
def update_offline_requests_manager(self, offline_requests_manager):
|
||||
self._offline_requests_manager = offline_requests_manager
|
||||
|
||||
def update_draining_interval_sec(self, draining_interval_sec):
|
||||
self._draining_interval_sec = draining_interval_sec
|
||||
|
||||
def get_draining_interval_sec(self):
|
||||
return self._draining_interval_sec
|
||||
|
||||
def is_running(self):
|
||||
return self._is_running
|
||||
|
||||
def start(self):
    """Launch the daemon thread that drains the event queue."""
    self._stopper.clear()
    self._is_running = True
    consumer_thread = Thread(target=self._dispatch)
    consumer_thread.daemon = True
    consumer_thread.start()
    self._logger.debug("Event consuming thread started")
|
||||
|
||||
def stop(self):
    """Stop event consumption and release resources; no-op when not running."""
    if not self._is_running:
        return
    self._is_running = False
    self._clean_up()
    self._logger.debug("Event consuming thread stopped")
|
||||
|
||||
def _clean_up(self):
    """Drop pending events, stop background network I/O and clear callbacks."""
    self._logger.debug("Cleaning up before stopping event consuming")
    # Clear the underlying deque under the queue's own mutex so no event
    # is dispatched after shutdown began.
    with self._event_queue.mutex:
        self._event_queue.queue.clear()
    self._logger.debug("Event queue cleared")
    self._internal_async_client.stop_background_network_io()
    self._logger.debug("Network thread stopped")
    self._internal_async_client.clean_up_event_callbacks()
    self._logger.debug("Event callbacks cleared")
|
||||
|
||||
def wait_until_it_stops(self, timeout_sec):
|
||||
self._logger.debug("Waiting for event consumer to completely stop")
|
||||
return self._stopper.wait(timeout=timeout_sec)
|
||||
|
||||
def is_fully_stopped(self):
|
||||
return self._stopper.is_set()
|
||||
|
||||
def _dispatch(self):
    """Dispatch loop body run on the consumer thread.

    Sleeps on the condition variable when the queue is empty (waking at
    least every MAX_DISPATCH_INTERNAL_SEC to re-check the running flag),
    otherwise drains every currently queued event.
    """
    while self._is_running:
        with self._cv:
            if self._event_queue.empty():
                self._cv.wait(self.MAX_DISPATCH_INTERNAL_SEC)
            else:
                while not self._event_queue.empty():
                    self._dispatch_one()
    self._stopper.set()  # unblocks wait_until_it_stops/is_fully_stopped
    self._logger.debug("Exiting dispatching loop...")
|
||||
|
||||
def _dispatch_one(self):
    """Dequeue one event, run its type handler, then fire the user callback."""
    mid, event_type, data = self._event_queue.get()
    if mid:
        self._dispatch_methods[event_type](mid, data)
        self._internal_async_client.invoke_event_callback(mid, data=data)
        # We need to make sure disconnect event gets dispatched and then we stop the consumer
        if self._need_to_stop_dispatching(mid):
            self.stop()
|
||||
|
||||
def _need_to_stop_dispatching(self, mid):
    """True when a disconnect event completes a user- or connect-phase flow."""
    status = self._client_status.get_status()
    is_expected_phase = status == ClientStatus.USER_DISCONNECT or status == ClientStatus.CONNECT
    return is_expected_phase and mid == FixedEventMids.DISCONNECT_MID
|
||||
|
||||
def _dispatch_connack(self, mid, rc):
    """On CONNACK: kick off recovery (resubscribe + drain) or mark stable."""
    status = self._client_status.get_status()
    self._logger.debug("Dispatching [connack] event")
    if self._need_recover():
        if ClientStatus.STABLE != status:  # To avoid multiple connack dispatching
            self._logger.debug("Has recovery job")
            # Run recovery off this thread so event dispatching is not blocked
            clean_up_debt = Thread(target=self._clean_up_debt)
            clean_up_debt.start()
    else:
        self._logger.debug("No need for recovery")
        self._client_status.set_status(ClientStatus.STABLE)
|
||||
|
||||
def _need_recover(self):
|
||||
return self._subscription_manager.list_records() or self._offline_requests_manager.has_more()
|
||||
|
||||
def _clean_up_debt(self):
    # Recovery job run on its own thread after a reconnect: restore recorded
    # subscriptions, replay queued offline requests, then mark the client stable.
    self._handle_resubscribe()
    self._handle_draining()
    self._client_status.set_status(ClientStatus.STABLE)
|
||||
|
||||
def _handle_resubscribe(self):
    """Re-establish every recorded subscription after a reconnect.

    Aborts early if a user disconnect request arrives mid-loop.
    """
    subscriptions = self._subscription_manager.list_records()
    if subscriptions and not self._has_user_disconnect_request():
        self._logger.debug("Start resubscribing")
        self._client_status.set_status(ClientStatus.RESUBSCRIBE)
        for topic, (qos, message_callback, ack_callback) in subscriptions:
            if self._has_user_disconnect_request():
                self._logger.debug("User disconnect detected")
                break
            self._internal_async_client.subscribe(topic, qos, ack_callback)
|
||||
|
||||
def _handle_draining(self):
    """Replay queued offline requests, pausing the draining interval between each.

    Aborts early if a user disconnect request arrives mid-drain.
    """
    if self._offline_requests_manager.has_more() and not self._has_user_disconnect_request():
        self._logger.debug("Start draining")
        self._client_status.set_status(ClientStatus.DRAINING)
        while self._offline_requests_manager.has_more():
            if self._has_user_disconnect_request():
                self._logger.debug("User disconnect detected")
                break
            offline_request = self._offline_requests_manager.get_next()
            if offline_request:
                # Dispatch by request type, then throttle the replay rate
                self._offline_request_handlers[offline_request.type](offline_request)
                time.sleep(self._draining_interval_sec)
|
||||
|
||||
def _has_user_disconnect_request(self):
    """True when the user has explicitly requested a disconnect."""
    return self._client_status.get_status() == ClientStatus.USER_DISCONNECT
|
||||
|
||||
def _dispatch_disconnect(self, mid, rc):
    """Mark unexpected disconnects as abnormal; expected ones need no change."""
    self._logger.debug("Dispatching [disconnect] event")
    status = self._client_status.get_status()
    is_expected = status == ClientStatus.USER_DISCONNECT or status == ClientStatus.CONNECT
    if not is_expected:
        self._client_status.set_status(ClientStatus.ABNORMAL_DISCONNECT)
|
||||
|
||||
# For puback, suback and unsuback, ack callback invocation is handled in dispatch_one
|
||||
# Do nothing in the event dispatching itself
|
||||
def _dispatch_puback(self, mid, rc):
|
||||
self._logger.debug("Dispatching [puback] event")
|
||||
|
||||
def _dispatch_suback(self, mid, rc):
|
||||
self._logger.debug("Dispatching [suback] event")
|
||||
|
||||
def _dispatch_unsuback(self, mid, rc):
|
||||
self._logger.debug("Dispatching [unsuback] event")
|
||||
|
||||
def _dispatch_message(self, mid, message):
    """Fan an inbound message out to every matching subscription callback."""
    self._logger.debug("Dispatching [message] event")
    subscriptions = self._subscription_manager.list_records()
    if subscriptions:
        for topic, (qos, message_callback, _) in subscriptions:
            # topic_matches_sub honors MQTT wildcards (+/#) in the filter
            if topic_matches_sub(topic, message.topic) and message_callback:
                message_callback(None, None, message)  # message_callback(client, userdata, message)
|
||||
|
||||
def _handle_offline_publish(self, request):
|
||||
topic, payload, qos, retain = request.data
|
||||
self._internal_async_client.publish(topic, payload, qos, retain)
|
||||
self._logger.debug("Processed offline publish request")
|
||||
|
||||
def _handle_offline_subscribe(self, request):
|
||||
topic, qos, message_callback, ack_callback = request.data
|
||||
self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
|
||||
self._internal_async_client.subscribe(topic, qos, ack_callback)
|
||||
self._logger.debug("Processed offline subscribe request")
|
||||
|
||||
def _handle_offline_unsubscribe(self, request):
|
||||
topic, ack_callback = request.data
|
||||
self._subscription_manager.remove_record(topic)
|
||||
self._internal_async_client.unsubscribe(topic, ack_callback)
|
||||
self._logger.debug("Processed offline unsubscribe request")
|
||||
|
||||
|
||||
class SubscriptionManager(object):
    """Bookkeeping of active subscriptions so they survive reconnects.

    Each record maps a topic filter to (qos, message_callback, ack_callback);
    the event consumer replays these records when the connection recovers.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._subscription_map = dict()

    def add_record(self, topic, qos, message_callback, ack_callback):
        """Remember a subscription; message_callback and/or ack_callback may be None."""
        self._logger.debug("Adding a new subscription record: %s qos: %d", topic, qos)
        self._subscription_map[topic] = qos, message_callback, ack_callback

    def remove_record(self, topic):
        """Forget a subscription; unknown topics are logged and ignored."""
        self._logger.debug("Removing subscription record: %s", topic)
        # Membership test instead of truthiness of the stored value
        if topic in self._subscription_map:
            del self._subscription_map[topic]
        else:
            # Logger.warn is deprecated; warning() is the supported spelling
            self._logger.warning("Removing attempt for non-exist subscription record: %s", topic)

    def list_records(self):
        """Return current records as a list of (topic, (qos, msg_cb, ack_cb))."""
        return list(self._subscription_map.items())
|
||||
|
||||
|
||||
class OfflineRequestsManager(object):
    """Facade over OfflineRequestQueue used to buffer requests while offline."""

    _logger = logging.getLogger(__name__)

    def __init__(self, max_size, drop_behavior):
        self._queue = OfflineRequestQueue(max_size, drop_behavior)

    def has_more(self):
        """True when at least one queued request is pending."""
        return len(self._queue) > 0

    def add_one(self, request):
        """Queue one request; returns the queue's append result code."""
        return self._queue.append(request)

    def get_next(self):
        """Pop and return the oldest queued request, or None when empty."""
        if not self.has_more():
            return None
        return self._queue.pop(0)
|
||||
@@ -0,0 +1,373 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import AWSIoTPythonSDK
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager
|
||||
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
|
||||
from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_OPERATION_TIMEOUT_SEC
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException
|
||||
from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults
|
||||
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv31
|
||||
from threading import Condition
|
||||
from threading import Event
|
||||
import logging
|
||||
import sys
|
||||
if sys.version_info[0] < 3:
|
||||
from Queue import Queue
|
||||
else:
|
||||
from queue import Queue
|
||||
|
||||
|
||||
class MqttCore(object):
    """Core MQTT engine behind the public SDK clients.

    Wires together the internal paho-based async client, the event
    producer/consumer pair, subscription bookkeeping and the offline request
    queue, and exposes connect/disconnect/publish/subscribe/unsubscribe in
    both blocking (sync) and async flavors.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, client_id, clean_session, protocol, use_wss):
        """Build the worker graph; no network activity happens until connect."""
        self._use_wss = use_wss
        self._username = ""
        self._password = None
        self._enable_metrics_collection = True
        self._event_queue = Queue()
        self._event_cv = Condition()
        self._event_producer = EventProducer(self._event_cv, self._event_queue)
        self._client_status = ClientStatusContainer()
        self._internal_async_client = InternalAsyncMqttClient(client_id, clean_session, protocol, use_wss)
        self._subscription_manager = SubscriptionManager()
        self._offline_requests_manager = OfflineRequestsManager(-1, DropBehaviorTypes.DROP_NEWEST)  # Infinite queue
        self._event_consumer = EventConsumer(self._event_cv,
                                             self._event_queue,
                                             self._internal_async_client,
                                             self._subscription_manager,
                                             self._offline_requests_manager,
                                             self._client_status)
        self._connect_disconnect_timeout_sec = DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
        self._operation_timeout_sec = DEFAULT_OPERATION_TIMEOUT_SEC
        self._init_offline_request_exceptions()
        self._init_workers()
        self._logger.info("MqttCore initialized")
        self._logger.info("Client id: %s" % client_id)
        self._logger.info("Protocol version: %s" % ("MQTTv3.1" if protocol == MQTTv31 else "MQTTv3.1.1"))
        self._logger.info("Authentication type: %s" % ("SigV4 WebSocket" if use_wss else "TLSv1.2 certificate based Mutual Auth."))

    def _init_offline_request_exceptions(self):
        # Pre-built exception instances raised by _handle_offline_request,
        # keyed by the offline request type.
        self._offline_request_queue_disabled_exceptions = {
            RequestTypes.PUBLISH : publishQueueDisabledException(),
            RequestTypes.SUBSCRIBE : subscribeQueueDisabledException(),
            RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException()
        }
        self._offline_request_queue_full_exceptions = {
            RequestTypes.PUBLISH : publishQueueFullException(),
            RequestTypes.SUBSCRIBE : subscribeQueueFullException(),
            RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException()
        }

    def _init_workers(self):
        # Route all paho callbacks through the event producer so the consumer
        # thread is the only place user callbacks get invoked.
        self._internal_async_client.register_internal_event_callbacks(self._event_producer.on_connect,
                                                                      self._event_producer.on_disconnect,
                                                                      self._event_producer.on_publish,
                                                                      self._event_producer.on_subscribe,
                                                                      self._event_producer.on_unsubscribe,
                                                                      self._event_producer.on_message)

    def _start_workers(self):
        # Start the event consuming thread (called on every connect attempt)
        self._event_consumer.start()

    def use_wss(self):
        """True when this core was built for SigV4 WebSocket transport."""
        return self._use_wss

    # Used for general message event reception
    def on_message(self, message):
        pass

    # Used for general online event notification
    def on_online(self):
        pass

    # Used for general offline event notification
    def on_offline(self):
        pass

    def configure_cert_credentials(self, cert_credentials_provider):
        """Install the provider for certificate-based mutual auth credentials."""
        self._logger.info("Configuring certificates...")
        self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider)

    def configure_iam_credentials(self, iam_credentials_provider):
        """Install the provider for SigV4 WebSocket IAM credentials."""
        self._logger.info("Configuring custom IAM credentials...")
        self._internal_async_client.set_iam_credentials_provider(iam_credentials_provider)

    def configure_endpoint(self, endpoint_provider):
        """Install the provider of the broker host/port."""
        self._logger.info("Configuring endpoint...")
        self._internal_async_client.set_endpoint_provider(endpoint_provider)

    def configure_connect_disconnect_timeout_sec(self, connect_disconnect_timeout_sec):
        """Set how long sync connect/disconnect wait for their acks."""
        self._logger.info("Configuring connect/disconnect time out: %f sec" % connect_disconnect_timeout_sec)
        self._connect_disconnect_timeout_sec = connect_disconnect_timeout_sec

    def configure_operation_timeout_sec(self, operation_timeout_sec):
        """Set how long sync publish/subscribe/unsubscribe wait for their acks."""
        self._logger.info("Configuring MQTT operation time out: %f sec" % operation_timeout_sec)
        self._operation_timeout_sec = operation_timeout_sec

    def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
        """Set the progressive backoff parameters for auto-reconnect."""
        self._logger.info("Configuring reconnect back off timing...")
        self._logger.info("Base quiet time: %f sec" % base_reconnect_quiet_sec)
        self._logger.info("Max quiet time: %f sec" % max_reconnect_quiet_sec)
        self._logger.info("Stable connection time: %f sec" % stable_connection_sec)
        self._internal_async_client.configure_reconnect_back_off(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)

    def configure_alpn_protocols(self):
        """Enable the SDK's default ALPN protocol list on the TLS layer."""
        self._logger.info("Configuring alpn protocols...")
        self._internal_async_client.configure_alpn_protocols([ALPN_PROTCOLS])

    def configure_last_will(self, topic, payload, qos, retain=False):
        """Set the MQTT last-will message for subsequent connects."""
        self._logger.info("Configuring last will...")
        self._internal_async_client.configure_last_will(topic, payload, qos, retain)

    def clear_last_will(self):
        """Remove any previously configured last-will message."""
        self._logger.info("Clearing last will...")
        self._internal_async_client.clear_last_will()

    def configure_username_password(self, username, password=None):
        """Store MQTT username/password; applied at connect time."""
        self._logger.info("Configuring username and password...")
        self._username = username
        self._password = password

    def configure_socket_factory(self, socket_factory):
        """Install a custom socket factory for the underlying transport."""
        self._logger.info("Configuring socket factory...")
        self._internal_async_client.set_socket_factory(socket_factory)

    def enable_metrics_collection(self):
        # When enabled, the SDK version is appended to the MQTT username
        self._enable_metrics_collection = True

    def disable_metrics_collection(self):
        self._enable_metrics_collection = False

    def configure_offline_requests_queue(self, max_size, drop_behavior):
        """Replace the offline request queue with one of the given size/behavior."""
        self._logger.info("Configuring offline requests queueing: max queue size: %d", max_size)
        self._offline_requests_manager = OfflineRequestsManager(max_size, drop_behavior)
        self._event_consumer.update_offline_requests_manager(self._offline_requests_manager)

    def configure_draining_interval_sec(self, draining_interval_sec):
        """Set the pause between replayed requests when draining after reconnect."""
        self._logger.info("Configuring offline requests queue draining interval: %f sec", draining_interval_sec)
        self._event_consumer.update_draining_interval_sec(draining_interval_sec)

    def connect(self, keep_alive_sec):
        """Blocking connect.

        Raises connectTimeoutException when no CONNACK arrives within the
        configured connect/disconnect timeout; returns True on success.
        """
        self._logger.info("Performing sync connect...")
        event = Event()
        self.connect_async(keep_alive_sec, self._create_blocking_ack_callback(event))
        if not event.wait(self._connect_disconnect_timeout_sec):
            self._logger.error("Connect timed out")
            raise connectTimeoutException()
        return True

    def connect_async(self, keep_alive_sec, ack_callback=None):
        """Non-blocking connect; returns the fixed CONNACK mid.

        Starts the worker threads first; on any connect error the workers are
        torn down again and the client status reset to IDLE before re-raising.
        """
        self._logger.info("Performing async connect...")
        self._logger.info("Keep-alive: %f sec" % keep_alive_sec)
        self._start_workers()
        self._load_callbacks()
        self._load_username_password()

        try:
            self._client_status.set_status(ClientStatus.CONNECT)
            rc = self._internal_async_client.connect(keep_alive_sec, ack_callback)
            if MQTT_ERR_SUCCESS != rc:
                self._logger.error("Connect error: %d", rc)
                raise connectError(rc)
        except Exception as e:
            # Provided any error in connect, we should clean up the threads that have been created
            self._event_consumer.stop()
            if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
                self._logger.error("Time out in waiting for event consumer to stop")
            else:
                self._logger.debug("Event consumer stopped")
            self._client_status.set_status(ClientStatus.IDLE)
            raise e

        return FixedEventMids.CONNACK_MID

    def _load_callbacks(self):
        # Hand the user-overridable notification hooks to the internal client
        self._logger.debug("Passing in general notification callbacks to internal client...")
        self._internal_async_client.on_online = self.on_online
        self._internal_async_client.on_offline = self.on_offline
        self._internal_async_client.on_message = self.on_message

    def _load_username_password(self):
        # Optionally append the SDK metrics suffix to the configured username
        username_candidate = self._username
        if self._enable_metrics_collection:
            username_candidate += METRICS_PREFIX
            username_candidate += AWSIoTPythonSDK.__version__
        self._internal_async_client.set_username_password(username_candidate, self._password)

    def disconnect(self):
        """Blocking disconnect; waits for the ack AND for the consumer to stop."""
        self._logger.info("Performing sync disconnect...")
        event = Event()
        self.disconnect_async(self._create_blocking_ack_callback(event))
        if not event.wait(self._connect_disconnect_timeout_sec):
            self._logger.error("Disconnect timed out")
            raise disconnectTimeoutException()
        if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
            self._logger.error("Disconnect timed out in waiting for event consumer")
            raise disconnectTimeoutException()
        return True

    def disconnect_async(self, ack_callback=None):
        """Non-blocking disconnect; returns the fixed DISCONNECT mid."""
        self._logger.info("Performing async disconnect...")
        self._client_status.set_status(ClientStatus.USER_DISCONNECT)
        rc = self._internal_async_client.disconnect(ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Disconnect error: %d", rc)
            raise disconnectError(rc)
        return FixedEventMids.DISCONNECT_MID

    def publish(self, topic, payload, qos, retain=False):
        """Blocking publish.

        When the client is not STABLE the request is queued for later replay.
        QoS > 0 waits for the PUBACK up to the operation timeout.
        """
        self._logger.info("Performing sync publish...")
        ret = False
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
        else:
            if qos > 0:
                event = Event()
                rc, mid = self._publish_async(topic, payload, qos, retain, self._create_blocking_ack_callback(event))
                if not event.wait(self._operation_timeout_sec):
                    self._internal_async_client.remove_event_callback(mid)
                    self._logger.error("Publish timed out")
                    raise publishTimeoutException()
            else:
                self._publish_async(topic, payload, qos, retain)
            ret = True
        return ret

    def publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
        """Non-blocking publish; returns the packet mid (QUEUED_MID when offline)."""
        self._logger.info("Performing async publish...")
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
            return FixedEventMids.QUEUED_MID
        else:
            rc, mid = self._publish_async(topic, payload, qos, retain, ack_callback)
            return mid

    def _publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
        # Raises publishError on a non-success return code from the client
        rc, mid = self._internal_async_client.publish(topic, payload, qos, retain, ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Publish error: %d", rc)
            raise publishError(rc)
        return rc, mid

    def subscribe(self, topic, qos, message_callback=None):
        """Blocking subscribe; queues the request when the client is not STABLE."""
        self._logger.info("Performing sync subscribe...")
        ret = False
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, None))
        else:
            event = Event()
            rc, mid = self._subscribe_async(topic, qos, self._create_blocking_ack_callback(event), message_callback)
            if not event.wait(self._operation_timeout_sec):
                self._internal_async_client.remove_event_callback(mid)
                self._logger.error("Subscribe timed out")
                raise subscribeTimeoutException()
            ret = True
        return ret

    def subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
        """Non-blocking subscribe; returns the packet mid (QUEUED_MID when offline)."""
        self._logger.info("Performing async subscribe...")
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, ack_callback))
            return FixedEventMids.QUEUED_MID
        else:
            rc, mid = self._subscribe_async(topic, qos, ack_callback, message_callback)
            return mid

    def _subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
        # Record first so the subscription survives reconnects, then send
        self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
        rc, mid = self._internal_async_client.subscribe(topic, qos, ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Subscribe error: %d", rc)
            raise subscribeError(rc)
        return rc, mid

    def unsubscribe(self, topic):
        """Blocking unsubscribe; queues the request when the client is not STABLE."""
        self._logger.info("Performing sync unsubscribe...")
        ret = False
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, None))
        else:
            event = Event()
            rc, mid = self._unsubscribe_async(topic, self._create_blocking_ack_callback(event))
            if not event.wait(self._operation_timeout_sec):
                self._internal_async_client.remove_event_callback(mid)
                self._logger.error("Unsubscribe timed out")
                raise unsubscribeTimeoutException()
            ret = True
        return ret

    def unsubscribe_async(self, topic, ack_callback=None):
        """Non-blocking unsubscribe; returns the packet mid (QUEUED_MID when offline)."""
        self._logger.info("Performing async unsubscribe...")
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, ack_callback))
            return FixedEventMids.QUEUED_MID
        else:
            rc, mid = self._unsubscribe_async(topic, ack_callback)
            return mid

    def _unsubscribe_async(self, topic, ack_callback=None):
        # Drop the local record first, then send the unsubscribe
        self._subscription_manager.remove_record(topic)
        rc, mid = self._internal_async_client.unsubscribe(topic, ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Unsubscribe error: %d", rc)
            raise unsubscribeError(rc)
        return rc, mid

    def _create_blocking_ack_callback(self, event):
        # Wrap a threading.Event so a sync API can wait for the matching ack
        def ack_callback(mid, data=None):
            event.set()
        return ack_callback

    def _handle_offline_request(self, type, data):
        """Queue a request issued while offline; raise when queueing is refused."""
        self._logger.info("Offline request detected!")
        offline_request = QueueableRequest(type, data)
        append_result = self._offline_requests_manager.add_one(offline_request)
        if AppendResults.APPEND_FAILURE_QUEUE_DISABLED == append_result:
            self._logger.error("Offline request queue has been disabled")
            raise self._offline_request_queue_disabled_exceptions[type]
        if AppendResults.APPEND_FAILURE_QUEUE_FULL == append_result:
            self._logger.error("Offline request queue is full")
            raise self._offline_request_queue_full_exceptions[type]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,430 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from threading import Timer, Lock, Thread
|
||||
|
||||
|
||||
class _shadowRequestToken:
|
||||
|
||||
URN_PREFIX_LENGTH = 9
|
||||
|
||||
def getNextToken(self):
|
||||
return uuid.uuid4().urn[self.URN_PREFIX_LENGTH:] # We only need the uuid digits, not the urn prefix
|
||||
|
||||
|
||||
class _basicJSONParser:
|
||||
|
||||
def setString(self, srcString):
|
||||
self._rawString = srcString
|
||||
self._dictionObject = None
|
||||
|
||||
def regenerateString(self):
|
||||
return json.dumps(self._dictionaryObject)
|
||||
|
||||
def getAttributeValue(self, srcAttributeKey):
|
||||
return self._dictionaryObject.get(srcAttributeKey)
|
||||
|
||||
def setAttributeValue(self, srcAttributeKey, srcAttributeValue):
|
||||
self._dictionaryObject[srcAttributeKey] = srcAttributeValue
|
||||
|
||||
def validateJSON(self):
|
||||
try:
|
||||
self._dictionaryObject = json.loads(self._rawString)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class deviceShadow:
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, srcShadowName, srcIsPersistentSubscribe, srcShadowManager):
    """Client-side handle for one named device shadow.

    Supports Get, Update, Delete, delta listening and delta-listen
    cancellation against the corresponding shadow JSON document in the AWS
    IoT cloud. Instances are returned by
    :code:`AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTShadowClient.createShadowWithName`
    and are not constructed directly from user scripts.
    """
    if srcShadowName is None or srcIsPersistentSubscribe is None or srcShadowManager is None:
        raise TypeError("None type inputs detected.")
    self._shadowName = srcShadowName
    # Tool handlers
    self._shadowManagerHandler = srcShadowManager
    self._basicJSONParserHandler = _basicJSONParser()
    self._tokenHandler = _shadowRequestToken()
    # Properties
    self._isPersistentSubscribe = srcIsPersistentSubscribe
    self._lastVersionInSync = -1  # -1 means not initialized
    self._isGetSubscribed = False
    self._isUpdateSubscribed = False
    self._isDeleteSubscribed = False
    # Per-operation user callbacks and outstanding-response counters
    self._shadowSubscribeCallbackTable = {
        "get": None,
        "delete": None,
        "update": None,
        "delta": None,
    }
    self._shadowSubscribeStatusTable = {"get": 0, "delete": 0, "update": 0}
    self._tokenPool = dict()
    self._dataStructureLock = Lock()
|
||||
|
||||
def _doNonPersistentUnsubscribe(self, currentAction):
    # Tear down the accepted/rejected topic subscriptions created for a
    # one-shot (non-persistent) shadow operation once it has completed.
    self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, currentAction)
    self._logger.info("Unsubscribed to " + currentAction + " accepted/rejected topics for deviceShadow: " + self._shadowName)
|
||||
|
||||
def generalCallback(self, client, userdata, message):
    """Paho-style message callback shared by all shadow response topics.

    Dispatches an incoming shadow response to the registered user callback,
    keeps the locally tracked shadow version in sync, cancels the matching
    timeout timer, and (for non-persistent mode) schedules unsubscription
    once no requests are pending. All state is mutated under
    self._dataStructureLock; user callbacks and unsubscribes run on
    separate threads so this callback never blocks the network loop.
    """
    # In Py3.x, message.payload comes in as a bytes(string)
    # json.loads needs a string input
    with self._dataStructureLock:
        currentTopic = message.topic
        currentAction = self._parseTopicAction(currentTopic)  # get/delete/update/delta
        currentType = self._parseTopicType(currentTopic)  # accepted/rejected/delta
        payloadUTF8String = message.payload.decode('utf-8')
        # get/delete/update: Need to deal with token, timer and unsubscribe
        if currentAction in ["get", "delete", "update"]:
            # Check for token
            self._basicJSONParserHandler.setString(payloadUTF8String)
            if self._basicJSONParserHandler.validateJSON():  # Filter out invalid JSON
                currentToken = self._basicJSONParserHandler.getAttributeValue(u"clientToken")
                if currentToken is not None:
                    self._logger.debug("shadow message clientToken: " + currentToken)
                if currentToken is not None and currentToken in self._tokenPool.keys():  # Filter out JSON without the desired token
                    # Sync local version when it is an accepted response
                    self._logger.debug("Token is in the pool. Type: " + currentType)
                    if currentType == "accepted":
                        incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
                        # If it is get/update accepted response, we need to sync the local version
                        if incomingVersion is not None and incomingVersion > self._lastVersionInSync and currentAction != "delete":
                            self._lastVersionInSync = incomingVersion
                        # If it is a delete accepted, we need to reset the version
                        # NOTE(review): this else also fires for accepted get/update
                        # responses carrying no version or a stale version, which
                        # resets the sync point to -1 — presumably intentional so the
                        # next delta always wins; confirm before changing.
                        else:
                            self._lastVersionInSync = -1  # The version will always be synced for the next incoming delta/GU-accepted response
                    # Cancel the timer and clear the token
                    self._tokenPool[currentToken].cancel()
                    del self._tokenPool[currentToken]
                    # Need to unsubscribe?
                    self._shadowSubscribeStatusTable[currentAction] -= 1
                    if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(currentAction) <= 0:
                        self._shadowSubscribeStatusTable[currentAction] = 0
                        # Unsubscribe on a worker thread: doing it inline would call
                        # back into the MQTT client from its own callback context.
                        processNonPersistentUnsubscribe = Thread(target=self._doNonPersistentUnsubscribe, args=[currentAction])
                        processNonPersistentUnsubscribe.start()
                    # Custom callback
                    if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
                        processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, currentToken])
                        processCustomCallback.start()
        # delta: Watch for version
        else:
            # Delta responses carry no clientToken; tag the type with the shadow
            # name parsed from the topic instead.
            currentType += "/" + self._parseTopicShadowName(currentTopic)
            # Sync local version
            self._basicJSONParserHandler.setString(payloadUTF8String)
            if self._basicJSONParserHandler.validateJSON():  # Filter out JSON without version
                incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
                # Only deliver deltas newer than what we have already seen.
                if incomingVersion is not None and incomingVersion > self._lastVersionInSync:
                    self._lastVersionInSync = incomingVersion
                    # Custom callback
                    if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
                        processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, None])
                        processCustomCallback.start()
|
||||
|
||||
def _parseTopicAction(self, srcTopic):
|
||||
ret = None
|
||||
fragments = srcTopic.split('/')
|
||||
if fragments[5] == "delta":
|
||||
ret = "delta"
|
||||
else:
|
||||
ret = fragments[4]
|
||||
return ret
|
||||
|
||||
def _parseTopicType(self, srcTopic):
|
||||
fragments = srcTopic.split('/')
|
||||
return fragments[5]
|
||||
|
||||
def _parseTopicShadowName(self, srcTopic):
|
||||
fragments = srcTopic.split('/')
|
||||
return fragments[2]
|
||||
|
||||
def _timerHandler(self, srcActionName, srcToken):
|
||||
with self._dataStructureLock:
|
||||
# Don't crash if we try to remove an unknown token
|
||||
if srcToken not in self._tokenPool:
|
||||
self._logger.warn('Tried to remove non-existent token from pool: %s' % str(srcToken))
|
||||
return
|
||||
# Remove the token
|
||||
del self._tokenPool[srcToken]
|
||||
# Need to unsubscribe?
|
||||
self._shadowSubscribeStatusTable[srcActionName] -= 1
|
||||
if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(srcActionName) <= 0:
|
||||
self._shadowSubscribeStatusTable[srcActionName] = 0
|
||||
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, srcActionName)
|
||||
# Notify time-out issue
|
||||
if self._shadowSubscribeCallbackTable.get(srcActionName) is not None:
|
||||
self._logger.info("Shadow request with token: " + str(srcToken) + " has timed out.")
|
||||
self._shadowSubscribeCallbackTable[srcActionName]("REQUEST TIME OUT", "timeout", srcToken)
|
||||
|
||||
def shadowGet(self, srcCallback, srcTimeout):
    """
    **Description**

    Retrieve the device shadow JSON document from AWS IoT by publishing an empty
    JSON document (carrying only a clientToken) to the shadow get topic. The get
    accepted/rejected topics are subscribed so the outcome reaches *srcCallback*;
    if no response arrives within *srcTimeout* seconds a timeout notification is
    delivered to the callback instead.

    **Parameters**

    *srcCallback* - Function of form :code:`customCallback(payload, responseStatus, token)`
    called when the response (or a timeout notification) for this request arrives.

    *srcTimeout* - Seconds to wait before the request is considered timed out.

    **Returns**

    The client token string used for tracing this request.

    """
    with self._dataStructureLock:
        # Register the response callback and count one more pending request.
        self._shadowSubscribeCallbackTable["get"] = srcCallback
        self._shadowSubscribeStatusTable["get"] += 1
        # Allocate a clientToken and arm (but do not yet start) its timeout timer.
        token = self._tokenHandler.getNextToken()
        self._tokenPool[token] = Timer(srcTimeout, self._timerHandler, ["get", token])
        # Build the request payload: an empty JSON document plus the clientToken.
        self._basicJSONParserHandler.setString("{}")
        self._basicJSONParserHandler.validateJSON()
        self._basicJSONParserHandler.setAttributeValue("clientToken", token)
        payload = self._basicJSONParserHandler.regenerateString()
    # Subscribe to the accepted/rejected response topics when not already done.
    if not self._isPersistentSubscribe or not self._isGetSubscribed:
        self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "get", self.generalCallback)
        self._isGetSubscribed = True
        self._logger.info("Subscribed to get accepted/rejected topics for deviceShadow: " + self._shadowName)
    # Publish the request, then start counting down toward the timeout.
    self._shadowManagerHandler.basicShadowPublish(self._shadowName, "get", payload)
    self._tokenPool[token].start()
    return token
|
||||
|
||||
def shadowDelete(self, srcCallback, srcTimeout):
    """
    **Description**

    Delete the device shadow from AWS IoT by publishing an empty JSON document
    (carrying only a clientToken) to the shadow delete topic. The delete
    accepted/rejected topics are subscribed so the outcome reaches *srcCallback*;
    if no response arrives within *srcTimeout* seconds a timeout notification is
    delivered to the callback instead.

    **Parameters**

    *srcCallback* - Function of form :code:`customCallback(payload, responseStatus, token)`
    called when the response (or a timeout notification) for this request arrives.

    *srcTimeout* - Seconds to wait before the request is considered timed out.

    **Returns**

    The client token string used for tracing this request.

    """
    with self._dataStructureLock:
        # Register the response callback and count one more pending request.
        self._shadowSubscribeCallbackTable["delete"] = srcCallback
        self._shadowSubscribeStatusTable["delete"] += 1
        # Allocate a clientToken and arm (but do not yet start) its timeout timer.
        token = self._tokenHandler.getNextToken()
        self._tokenPool[token] = Timer(srcTimeout, self._timerHandler, ["delete", token])
        # Build the request payload: an empty JSON document plus the clientToken.
        self._basicJSONParserHandler.setString("{}")
        self._basicJSONParserHandler.validateJSON()
        self._basicJSONParserHandler.setAttributeValue("clientToken", token)
        payload = self._basicJSONParserHandler.regenerateString()
    # Subscribe to the accepted/rejected response topics when not already done.
    if not self._isPersistentSubscribe or not self._isDeleteSubscribed:
        self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delete", self.generalCallback)
        self._isDeleteSubscribed = True
        self._logger.info("Subscribed to delete accepted/rejected topics for deviceShadow: " + self._shadowName)
    # Publish the request, then start counting down toward the timeout.
    self._shadowManagerHandler.basicShadowPublish(self._shadowName, "delete", payload)
    self._tokenPool[token].start()
    return token
|
||||
|
||||
def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout):
    """
    **Description**

    Update the device shadow JSON document in AWS IoT by publishing the provided
    JSON document (stamped with a clientToken) to the shadow update topic. The
    update accepted/rejected topics are subscribed so the outcome reaches
    *srcCallback*; if no response arrives within *srcTimeout* seconds a timeout
    notification is delivered to the callback instead.

    **Parameters**

    *srcJSONPayload* - JSON document string used to update the shadow in AWS IoT.

    *srcCallback* - Function of form :code:`customCallback(payload, responseStatus, token)`
    called when the response (or a timeout notification) for this request arrives.

    *srcTimeout* - Seconds to wait before the request is considered timed out.

    **Returns**

    The token used for tracing in this shadow request.

    **Raises**

    ValueError when *srcJSONPayload* is not valid JSON.

    """
    with self._dataStructureLock:
        # Fix vs. original: validate and stamp the payload while holding the
        # lock. The JSON parser handler is shared with generalCallback (which
        # uses it under this same lock), so touching it outside the lock let a
        # concurrent response overwrite the payload between validation and
        # clientToken insertion.
        self._basicJSONParserHandler.setString(srcJSONPayload)
        if not self._basicJSONParserHandler.validateJSON():
            raise ValueError("Invalid JSON file.")
        # clientToken and its (not yet started) timeout timer.
        currentToken = self._tokenHandler.getNextToken()
        self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken])
        self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
        JSONPayloadWithToken = self._basicJSONParserHandler.regenerateString()
        # Register the response callback and count one more pending request.
        self._shadowSubscribeCallbackTable["update"] = srcCallback
        self._shadowSubscribeStatusTable["update"] += 1
    # Subscribe to the accepted/rejected response topics when not already done.
    if not self._isPersistentSubscribe or not self._isUpdateSubscribed:
        self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "update", self.generalCallback)
        self._isUpdateSubscribed = True
        self._logger.info("Subscribed to update accepted/rejected topics for deviceShadow: " + self._shadowName)
    # Publish the request, then start counting down toward the timeout.
    self._shadowManagerHandler.basicShadowPublish(self._shadowName, "update", JSONPayloadWithToken)
    self._tokenPool[currentToken].start()
    return currentToken
|
||||
|
||||
def shadowRegisterDeltaCallback(self, srcCallback):
    """
    **Description**

    Start listening on the delta topic for this device shadow. Whenever the
    desired and reported states differ, *srcCallback* is invoked with the delta
    payload.

    **Parameters**

    *srcCallback* - Function of form :code:`customCallback(payload, responseStatus, token)`
    invoked with each incoming delta message.

    **Returns**

    None

    """
    with self._dataStructureLock:
        # Record the delta callback under the shared lock.
        self._shadowSubscribeCallbackTable["delta"] = srcCallback
    # Subscribe once to the delta topic.
    self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delta", self.generalCallback)
    self._logger.info("Subscribed to delta topic for deviceShadow: " + self._shadowName)
|
||||
|
||||
def shadowUnregisterDeltaCallback(self):
    """
    **Description**

    Stop listening on the delta topic for this device shadow by unsubscribing
    from it. No delta messages are delivered after this call, even if the
    desired and reported states differ.

    **Parameters**

    None

    **Returns**

    None

    """
    with self._dataStructureLock:
        # Fix vs. original: pop() with a default instead of `del`, so calling
        # this twice (or before any register) no longer raises KeyError.
        self._shadowSubscribeCallbackTable.pop("delta", None)
    # One unsubscription
    self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, "delta")
    self._logger.info("Unsubscribed to delta topics for deviceShadow: " + self._shadowName)
|
||||
@@ -0,0 +1,83 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import logging
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
class _shadowAction:
|
||||
_actionType = ["get", "update", "delete", "delta"]
|
||||
|
||||
def __init__(self, srcShadowName, srcActionName):
|
||||
if srcActionName is None or srcActionName not in self._actionType:
|
||||
raise TypeError("Unsupported shadow action.")
|
||||
self._shadowName = srcShadowName
|
||||
self._actionName = srcActionName
|
||||
self.isDelta = srcActionName == "delta"
|
||||
if self.isDelta:
|
||||
self._topicDelta = "$aws/things/" + str(self._shadowName) + "/shadow/update/delta"
|
||||
else:
|
||||
self._topicGeneral = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName)
|
||||
self._topicAccept = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/accepted"
|
||||
self._topicReject = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/rejected"
|
||||
|
||||
def getTopicGeneral(self):
|
||||
return self._topicGeneral
|
||||
|
||||
def getTopicAccept(self):
|
||||
return self._topicAccept
|
||||
|
||||
def getTopicReject(self):
|
||||
return self._topicReject
|
||||
|
||||
def getTopicDelta(self):
|
||||
return self._topicDelta
|
||||
|
||||
|
||||
class shadowManager:
    """Routes shadow publish/subscribe/unsubscribe requests to the MQTT core.

    Subscribe and unsubscribe operations are serialized through a dedicated
    lock so concurrent deviceShadow instances do not interleave them.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, srcMQTTCore):
        """Keep a reference to the MQTT core; raise TypeError when it is None."""
        if srcMQTTCore is None:
            raise TypeError("None type inputs detected.")
        self._mqttCoreHandler = srcMQTTCore
        self._shadowSubUnsubOperationLock = Lock()

    def basicShadowPublish(self, srcShadowName, srcShadowAction, srcPayload):
        """Publish *srcPayload* at QoS 0 to the general topic of the given action."""
        shadowAction = _shadowAction(srcShadowName, srcShadowAction)
        self._mqttCoreHandler.publish(shadowAction.getTopicGeneral(), srcPayload, 0, False)

    def basicShadowSubscribe(self, srcShadowName, srcShadowAction, srcCallback):
        """Subscribe *srcCallback* at QoS 0 to the response topic(s) for the action."""
        with self._shadowSubUnsubOperationLock:
            shadowAction = _shadowAction(srcShadowName, srcShadowAction)
            if shadowAction.isDelta:
                targets = (shadowAction.getTopicDelta(),)
            else:
                targets = (shadowAction.getTopicAccept(), shadowAction.getTopicReject())
            for topic in targets:
                self._mqttCoreHandler.subscribe(topic, 0, srcCallback)
            # Fixed settle delay carried over from the original implementation;
            # presumably gives the broker time to activate the subscriptions
            # before the request is published -- confirm before removing.
            time.sleep(2)

    def basicShadowUnsubscribe(self, srcShadowName, srcShadowAction):
        """Unsubscribe from the response topic(s) for the given shadow action."""
        with self._shadowSubUnsubOperationLock:
            shadowAction = _shadowAction(srcShadowName, srcShadowAction)
            if shadowAction.isDelta:
                self._mqttCoreHandler.unsubscribe(shadowAction.getTopicDelta())
            else:
                for topic in (shadowAction.getTopicAccept(), shadowAction.getTopicReject()):
                    self._logger.debug(topic)
                    self._mqttCoreHandler.unsubscribe(topic)
|
||||
19
aws-iot-device-sdk-python/AWSIoTPythonSDK/core/util/enums.py
Normal file
19
aws-iot-device-sdk-python/AWSIoTPythonSDK/core/util/enums.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class DropBehaviorTypes(object):
    """Enumeration of drop behaviors for offline request queueing.

    NOTE(review): the names suggest DROP_OLDEST discards the oldest queued
    request and DROP_NEWEST the newest when a bounded queue is full -- confirm
    against the queue implementation that consumes these constants.
    """
    DROP_OLDEST = 0
    DROP_NEWEST = 1
|
||||
@@ -0,0 +1,92 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class CredentialsProvider(object):
    """Base credentials provider holding the root CA certificate path."""

    def __init__(self):
        self._ca_path = ""

    def get_ca_path(self):
        """Return the configured root CA certificate path."""
        return self._ca_path

    def set_ca_path(self, ca_path):
        """Set the root CA certificate path."""
        self._ca_path = ca_path
|
||||
|
||||
|
||||
class CertificateCredentialsProvider(CredentialsProvider):
    """Credentials provider for X.509 mutual-auth: CA path plus client cert/key paths."""

    def __init__(self):
        super(CertificateCredentialsProvider, self).__init__()
        self._cert_path = ""
        self._key_path = ""

    def get_cert_path(self):
        """Return the client certificate path."""
        return self._cert_path

    def get_key_path(self):
        """Return the client private key path."""
        return self._key_path

    def set_cert_path(self, cert_path):
        """Set the client certificate path."""
        self._cert_path = cert_path

    def set_key_path(self, key_path):
        """Set the client private key path."""
        self._key_path = key_path
|
||||
|
||||
|
||||
class IAMCredentialsProvider(CredentialsProvider):
    """Credentials provider for SigV4/WebSocket auth: AWS access key, secret key and session token."""

    def __init__(self):
        super(IAMCredentialsProvider, self).__init__()
        self._aws_access_key_id = ""
        self._aws_secret_access_key = ""
        self._aws_session_token = ""

    def get_access_key_id(self):
        """Return the AWS access key id."""
        return self._aws_access_key_id

    def get_secret_access_key(self):
        """Return the AWS secret access key."""
        return self._aws_secret_access_key

    def get_session_token(self):
        """Return the AWS session token."""
        return self._aws_session_token

    def set_access_key_id(self, access_key_id):
        """Set the AWS access key id."""
        self._aws_access_key_id = access_key_id

    def set_secret_access_key(self, secret_access_key):
        """Set the AWS secret access key."""
        self._aws_secret_access_key = secret_access_key

    def set_session_token(self, session_token):
        """Set the AWS session token."""
        self._aws_session_token = session_token
|
||||
|
||||
|
||||
class EndpointProvider(object):
    """Simple holder for a connection endpoint (host name and port number)."""

    def __init__(self):
        self._host = ""
        self._port = -1  # -1 marks "not configured"

    def get_host(self):
        """Return the configured host name."""
        return self._host

    def get_port(self):
        """Return the configured port number (-1 when unset)."""
        return self._port

    def set_host(self, host):
        """Set the host name."""
        self._host = host

    def set_port(self, port):
        """Set the port number."""
        self._port = port
|
||||
@@ -0,0 +1,153 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import AWSIoTPythonSDK.exception.operationTimeoutException as operationTimeoutException
|
||||
import AWSIoTPythonSDK.exception.operationError as operationError
|
||||
|
||||
|
||||
# Serial Exception
|
||||
class acceptTimeoutException(Exception):
    """Raised when a serial accept does not complete in time (default message "Accept Timeout")."""

    def __init__(self, msg="Accept Timeout"):
        # Fix: forward the message to Exception so str(e) and e.args carry it;
        # the original only set .message, leaving tracebacks without detail.
        Exception.__init__(self, msg)
        self.message = msg
|
||||
|
||||
|
||||
# MQTT Operation Timeout Exception
|
||||
# Fix applied to every class below: initialize Exception directly with the
# message so str(e) / e.args carry it in tracebacks. Exception.__init__ is
# called explicitly because the intermediate base does not forward the message.

class connectTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT connect request exceeds its timeout."""

    def __init__(self, msg="Connect Timeout"):
        Exception.__init__(self, msg)
        self.message = msg


class disconnectTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT disconnect request exceeds its timeout."""

    def __init__(self, msg="Disconnect Timeout"):
        Exception.__init__(self, msg)
        self.message = msg


class publishTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT publish request exceeds its timeout."""

    def __init__(self, msg="Publish Timeout"):
        Exception.__init__(self, msg)
        self.message = msg


class subscribeTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT subscribe request exceeds its timeout."""

    def __init__(self, msg="Subscribe Timeout"):
        Exception.__init__(self, msg)
        self.message = msg


class unsubscribeTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT unsubscribe request exceeds its timeout."""

    def __init__(self, msg="Unsubscribe Timeout"):
        Exception.__init__(self, msg)
        self.message = msg
|
||||
|
||||
|
||||
# MQTT Operation Error
|
||||
# Fix applied to every class below: initialize Exception directly with the
# message so str(e) / e.args carry it in tracebacks. Exception.__init__ is
# called explicitly because the intermediate base does not forward the message.

class connectError(operationError.operationError):
    """MQTT connect failed with the given error code."""

    def __init__(self, errorCode):
        message = "Connect Error: " + str(errorCode)
        Exception.__init__(self, message)
        self.message = message


class disconnectError(operationError.operationError):
    """MQTT disconnect failed with the given error code."""

    def __init__(self, errorCode):
        message = "Disconnect Error: " + str(errorCode)
        Exception.__init__(self, message)
        self.message = message


class publishError(operationError.operationError):
    """MQTT publish failed with the given error code."""

    def __init__(self, errorCode):
        message = "Publish Error: " + str(errorCode)
        Exception.__init__(self, message)
        self.message = message


class publishQueueFullException(operationError.operationError):
    """Offline publish request dropped: the internal publish queue is full."""

    def __init__(self):
        message = "Internal Publish Queue Full"
        Exception.__init__(self, message)
        self.message = message


class publishQueueDisabledException(operationError.operationError):
    """Offline publish request dropped because queueing is disabled."""

    def __init__(self):
        message = "Offline publish request dropped because queueing is disabled"
        Exception.__init__(self, message)
        self.message = message


class subscribeError(operationError.operationError):
    """MQTT subscribe failed with the given error code."""

    def __init__(self, errorCode):
        message = "Subscribe Error: " + str(errorCode)
        Exception.__init__(self, message)
        self.message = message


class subscribeQueueFullException(operationError.operationError):
    """Offline subscribe request dropped: the internal subscribe queue is full."""

    def __init__(self):
        message = "Internal Subscribe Queue Full"
        Exception.__init__(self, message)
        self.message = message


class subscribeQueueDisabledException(operationError.operationError):
    """Offline subscribe request dropped because queueing is disabled."""

    def __init__(self):
        message = "Offline subscribe request dropped because queueing is disabled"
        Exception.__init__(self, message)
        self.message = message


class unsubscribeError(operationError.operationError):
    """MQTT unsubscribe failed with the given error code."""

    def __init__(self, errorCode):
        message = "Unsubscribe Error: " + str(errorCode)
        Exception.__init__(self, message)
        self.message = message


class unsubscribeQueueFullException(operationError.operationError):
    """Offline unsubscribe request dropped: the internal unsubscribe queue is full."""

    def __init__(self):
        message = "Internal Unsubscribe Queue Full"
        Exception.__init__(self, message)
        self.message = message


class unsubscribeQueueDisabledException(operationError.operationError):
    """Offline unsubscribe request dropped because queueing is disabled."""

    def __init__(self):
        message = "Offline unsubscribe request dropped because queueing is disabled"
        Exception.__init__(self, message)
        self.message = message
|
||||
|
||||
|
||||
# Websocket Error
|
||||
# Fix applied to both classes: initialize Exception directly with the message
# so str(e) / e.args carry it in tracebacks (the original only set .message).

class wssNoKeyInEnvironmentError(operationError.operationError):
    """Raised when AWS credentials are missing from the environment for WSS auth."""

    def __init__(self):
        message = "No AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY detected in $ENV."
        Exception.__init__(self, message)
        self.message = message


class wssHandShakeError(operationError.operationError):
    """Raised when the WebSocket Secure handshake fails."""

    def __init__(self):
        message = "Error in WSS handshake."
        Exception.__init__(self, message)
        self.message = message
|
||||
|
||||
|
||||
# Greengrass Discovery Error
|
||||
# Fix applied to every class below: initialize Exception directly with the
# message so str(e) / e.args carry it in tracebacks (originals only set .message).

class DiscoveryDataNotFoundException(operationError.operationError):
    """Greengrass discovery returned no data."""

    def __init__(self):
        message = "No discovery data found"
        Exception.__init__(self, message)
        self.message = message


class DiscoveryTimeoutException(operationTimeoutException.operationTimeoutException):
    """Greengrass discovery request timed out."""

    def __init__(self, message="Discovery request timed out"):
        Exception.__init__(self, message)
        self.message = message


class DiscoveryInvalidRequestException(operationError.operationError):
    """Greengrass discovery request was malformed."""

    def __init__(self):
        message = "Invalid discovery request"
        Exception.__init__(self, message)
        self.message = message


class DiscoveryUnauthorizedException(operationError.operationError):
    """Greengrass discovery request was not authorized."""

    def __init__(self):
        message = "Discovery request not authorized"
        Exception.__init__(self, message)
        self.message = message


class DiscoveryThrottlingException(operationError.operationError):
    """Too many Greengrass discovery requests were issued."""

    def __init__(self):
        message = "Too many discovery requests"
        Exception.__init__(self, message)
        self.message = message


class DiscoveryFailure(operationError.operationError):
    """Generic Greengrass discovery failure carrying a caller-supplied message."""

    def __init__(self, message):
        Exception.__init__(self, message)
        self.message = message


class ClientError(Exception):
    """Client-side usage error carrying a caller-supplied message."""

    def __init__(self, message):
        Exception.__init__(self, message)
        self.message = message
|
||||
@@ -0,0 +1,19 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class operationError(Exception):
    """Base class for SDK operation errors (default message "Operation Error")."""

    def __init__(self, msg="Operation Error"):
        # Fix: forward the message to Exception so str(e) and e.args carry it;
        # the original only set .message, leaving tracebacks without detail.
        super(operationError, self).__init__(msg)
        self.message = msg
|
||||
@@ -0,0 +1,19 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class operationTimeoutException(Exception):
    """Base class for SDK operation timeouts (default message "Operation Timeout")."""

    def __init__(self, msg="Operation Timeout"):
        # Fix: forward the message to Exception so str(e) and e.args carry it;
        # the original only set .message, leaving tracebacks without detail.
        super(operationTimeoutException, self).__init__(msg)
        self.message = msg
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,3 @@
|
||||
__version__ = "1.4.8"
|
||||
|
||||
|
||||
@@ -0,0 +1,466 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import json
|
||||
|
||||
|
||||
KEY_GROUP_LIST = "GGGroups"
|
||||
KEY_GROUP_ID = "GGGroupId"
|
||||
KEY_CORE_LIST = "Cores"
|
||||
KEY_CORE_ARN = "thingArn"
|
||||
KEY_CA_LIST = "CAs"
|
||||
KEY_CONNECTIVITY_INFO_LIST = "Connectivity"
|
||||
KEY_CONNECTIVITY_INFO_ID = "Id"
|
||||
KEY_HOST_ADDRESS = "HostAddress"
|
||||
KEY_PORT_NUMBER = "PortNumber"
|
||||
KEY_METADATA = "Metadata"
|
||||
|
||||
|
||||
class ConnectivityInfo(object):
    """
    Read-only view over a single connectivity entry (id, host, port and
    metadata) taken from a Greengrass discovery response. Instances are
    built by the SDK while decoding the response; there is no need to
    construct them from user scripts.
    """

    def __init__(self, id, host, port, metadata):
        self._info_id = id
        self._host_address = host
        self._port_number = port
        self._metadata_string = metadata

    @property
    def id(self):
        """Id of this connectivity information entry."""
        return self._info_id

    @property
    def host(self):
        """Host address advertised by this entry."""
        return self._host_address

    @property
    def port(self):
        """Port number advertised by this entry."""
        return self._port_number

    @property
    def metadata(self):
        """Metadata string attached to this entry."""
        return self._metadata_string
|
||||
|
||||
|
||||
class CoreConnectivityInfo(object):
    """
    Connectivity information for a single Greengrass core: its thing arn,
    the id of the group it belongs to, and the connectivity entries it
    advertises, keyed by entry id. Instances are built by the SDK while
    decoding a discovery response; there is no need to construct them from
    user scripts.
    """

    def __init__(self, coreThingArn, groupId):
        self._arn = coreThingArn
        self._owning_group_id = groupId
        self._entries_by_id = {}

    @property
    def coreThingArn(self):
        """Thing arn of this Greengrass core."""
        return self._arn

    @property
    def groupId(self):
        """Id of the Greengrass group this core belongs to."""
        return self._owning_group_id

    @property
    def connectivityInfoList(self):
        """All connectivity entries known for this core."""
        return list(self._entries_by_id.values())

    def getConnectivityInfo(self, id):
        """
        **Description**

        Look up one set of connectivity information by its id.

        **Syntax**

        .. code:: python

          myCoreConnectivityInfo.getConnectivityInfo("CoolId")

        **Parameters**

        *id* - The id for the desired connectivity information.

        **Return**

        :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object,
        or None if no entry has that id.

        """
        return self._entries_by_id.get(id)

    def appendConnectivityInfo(self, connectivityInfo):
        """
        **Description**

        Register a new set of connectivity information for this core, keyed
        by its id (replacing any previous entry with the same id). Used by
        the SDK internally; no need to call directly from user scripts.

        **Parameters**

        *connectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.ConnectivityInfo` object.

        **Returns**

        None

        """
        self._entries_by_id[connectivityInfo.id] = connectivityInfo
|
||||
|
||||
|
||||
class GroupConnectivityInfo(object):
    """
    Connectivity information for one Greengrass group: the group id, the
    cores that belong to the group (keyed by core thing arn) and the group
    CA content strings. Instances are built by the SDK while decoding a
    discovery response; there is no need to construct them from user scripts.
    """

    def __init__(self, groupId):
        self._gid = groupId
        self._cores_by_arn = {}
        self._group_cas = []

    @property
    def groupId(self):
        """Id of this Greengrass group."""
        return self._gid

    @property
    def coreConnectivityInfoList(self):
        """All Greengrass cores
        (:code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`)
        that belong to this group."""
        return list(self._cores_by_arn.values())

    @property
    def caList(self):
        """CA content strings for this Greengrass group."""
        return self._group_cas

    def getCoreConnectivityInfo(self, coreThingArn):
        """
        **Description**

        Look up the :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`
        object for a core by its thing arn.

        **Syntax**

        .. code:: python

          myGroupConnectivityInfo.getCoreConnectivityInfo("YourOwnArnString")

        **Parameters**

        *coreThingArn* - Thing arn for the desired Greengrass core.

        **Returns**

        :code:`AWSIoTPythonSDK.core.greengrass.discovery.CoreConnectivityInfo` object,
        or None if no core has that arn.

        """
        return self._cores_by_arn.get(coreThingArn)

    def appendCoreConnectivityInfo(self, coreConnectivityInfo):
        """
        **Description**

        Register core connectivity information with this group, keyed by the
        core's thing arn. Used by the SDK internally; no need to call
        directly from user scripts.

        **Parameters**

        *coreConnectivityInfo* - :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` object.

        **Returns**

        None

        """
        self._cores_by_arn[coreConnectivityInfo.coreThingArn] = coreConnectivityInfo

    def appendCa(self, ca):
        """
        **Description**

        Append a CA content string to this group. Used by the SDK internally;
        no need to call directly from user scripts.

        **Parameters**

        *ca* - Group CA content string.

        **Returns**

        None

        """
        self._group_cas.append(ca)
|
||||
|
||||
|
||||
class DiscoveryInfo(object):
    """
    Data model wrapping the raw JSON payload returned by a Greengrass
    discovery request, with helpers to view it at group, core or
    connectivity-entry granularity. Instances are built by the SDK; there is
    no need to construct them from user scripts.
    """

    def __init__(self, rawJson):
        self._raw_json = rawJson

    @property
    def rawJson(self):
        """Raw JSON response string of the discovery request, kept for users
        who prefer to process the payload themselves."""
        return self._raw_json

    def getAllCores(self):
        """
        **Description**

        Return every :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo`
        object in this discovery information, across all Greengrass groups.
        Designed for users who want to iterate over all available cores
        regardless of which group they belong to.

        **Returns**

        List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.CoreConnectivityInfo` objects.

        """
        all_cores = []
        for group_info in self.getAllGroups():
            all_cores.extend(group_info.coreConnectivityInfoList)
        return all_cores

    def getAllCas(self):
        """
        **Description**

        Return every :code:`(groupId, caContent)` pair in this discovery
        information, across all Greengrass groups. Designed for users who
        want to iterate over all available CAs regardless of group.

        **Returns**

        List of :code:`(groupId, caContent)` string pairs, where
        :code:`caContent` is the CA content string and :code:`groupId` is the
        id of the group that CA belongs to.

        """
        return [(group_info.groupId, ca_content)
                for group_info in self.getAllGroups()
                for ca_content in group_info.caList]

    def getAllGroups(self):
        """
        **Description**

        Return every :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo`
        object in this discovery information. Designed for users who want to
        iterate over all groups this Greengrass aware device belongs to.

        **Returns**

        List of :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo` objects.

        """
        return list(self.toObjectAtGroupLevel().values())

    def toObjectAtGroupLevel(self):
        """
        **Description**

        Return a dict keyed by group id string, mapping to the corresponding
        :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.GroupConnectivityInfo`
        object. Designed for users who know exactly which group, core and
        connectivity entry they want to use to connect.

        **Syntax**

        .. code:: python

          groupLevelDiscoveryInfoObj = myDiscoveryInfo.toObjectAtGroupLevel()

        **Returns**

        Dict of group id -> :code:`GroupConnectivityInfo` object.

        """
        parsed_response = json.loads(self._raw_json)
        groups_by_id = dict()

        for raw_group in parsed_response[KEY_GROUP_LIST]:
            decoded_group = self._decode_group_info(raw_group)
            groups_by_id[decoded_group.groupId] = decoded_group

        return groups_by_id

    def _decode_group_info(self, group_object):
        # Build one GroupConnectivityInfo from its raw JSON dict, decoding
        # each core and collecting the group CA strings.
        decoded_group = GroupConnectivityInfo(group_object[KEY_GROUP_ID])

        for raw_core in group_object[KEY_CORE_LIST]:
            decoded_core = self._decode_core_info(raw_core, decoded_group.groupId)
            decoded_group.appendCoreConnectivityInfo(decoded_core)

        for ca_content in group_object[KEY_CA_LIST]:
            decoded_group.appendCa(ca_content)

        return decoded_group

    def _decode_core_info(self, core_object, group_id):
        # Build one CoreConnectivityInfo, decoding each connectivity entry.
        # Metadata is optional in the payload and defaults to ''.
        decoded_core = CoreConnectivityInfo(core_object[KEY_CORE_ARN], group_id)

        for raw_entry in core_object[KEY_CONNECTIVITY_INFO_LIST]:
            decoded_core.appendConnectivityInfo(ConnectivityInfo(
                raw_entry[KEY_CONNECTIVITY_INFO_ID],
                raw_entry[KEY_HOST_ADDRESS],
                raw_entry[KEY_PORT_NUMBER],
                raw_entry.get(KEY_METADATA, '')))

        return decoded_core
|
||||
@@ -0,0 +1,426 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryInvalidRequestException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryUnauthorizedException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryDataNotFoundException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryThrottlingException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import DiscoveryFailure
|
||||
from AWSIoTPythonSDK.core.greengrass.discovery.models import DiscoveryInfo
|
||||
from AWSIoTPythonSDK.core.protocol.connection.alpn import SSLContextBuilder
|
||||
import re
|
||||
import sys
|
||||
import ssl
|
||||
import time
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
import platform
|
||||
# Windows reports the "would block" condition on non-blocking sockets as
# WSAEWOULDBLOCK rather than EAGAIN; normalize to a single constant for the
# socket error checks below.
if platform.system() == 'Windows':
    EAGAIN = errno.WSAEWOULDBLOCK
else:
    EAGAIN = errno.EAGAIN
|
||||
|
||||
|
||||
class DiscoveryInfoProvider(object):
    """Performs the Greengrass discovery process against AWS IoT.

    Opens a TCP connection, wraps it in mutually-authenticated TLS, sends a
    hand-assembled HTTP/1.1 GET request and parses the response into a
    :code:`DiscoveryInfo` object.
    """

    # Pieces used to hand-assemble the HTTP/1.1 discovery GET request.
    REQUEST_TYPE_PREFIX = "GET "
    PAYLOAD_PREFIX = "/greengrass/discover/thing/"
    PAYLOAD_SUFFIX = " HTTP/1.1\r\n"  # Space in the front
    HOST_PREFIX = "Host: "
    HOST_SUFFIX = "\r\n\r\n"
    # Regex fragments used to parse the response status line and headers.
    HTTP_PROTOCOL = r"HTTP/1.1 "
    CONTENT_LENGTH = r"content-length: "
    CONTENT_LENGTH_PATTERN = CONTENT_LENGTH + r"([0-9]+)\r\n"
    HTTP_RESPONSE_CODE_PATTERN = HTTP_PROTOCOL + r"([0-9]+) "

    # HTTP status codes the discovery service is expected to return.
    HTTP_SC_200 = "200"
    HTTP_SC_400 = "400"
    HTTP_SC_401 = "401"
    HTTP_SC_404 = "404"
    HTTP_SC_429 = "429"

    # Return codes for the low-level send/receive helpers below.
    LOW_LEVEL_RC_COMPLETE = 0
    LOW_LEVEL_RC_TIMEOUT = -1

    _logger = logging.getLogger(__name__)

    def __init__(self, caPath="", certPath="", keyPath="", host="", port=8443, timeoutSec=120):
        """

        The class that provides functionality to perform a Greengrass discovery process to the cloud.

        Users can perform Greengrass discovery process for a specific Greengrass aware device to retrieve
        connectivity/identity information of Greengrass cores within the same group.

        **Syntax**

        .. code:: python

          from AWSIoTPythonSDK.core.greengrass.discovery.providers import DiscoveryInfoProvider

          # Create a discovery information provider
          myDiscoveryInfoProvider = DiscoveryInfoProvider()
          # Create a discovery information provider with custom configuration
          myDiscoveryInfoProvider = DiscoveryInfoProvider(caPath=myCAPath, certPath=myCertPath, keyPath=myKeyPath, host=myHost, timeoutSec=myTimeoutSec)

        **Parameters**

        *caPath* - Path to read the root CA file.

        *certPath* - Path to read the certificate file.

        *keyPath* - Path to read the private key file.

        *host* - String that denotes the host name of the user-specific AWS IoT endpoint.

        *port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.

        *timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
        been timed out.

        **Returns**

        AWSIoTPythonSDK.core.greengrass.discovery.providers.DiscoveryInfoProvider object

        """
        self._ca_path = caPath
        self._cert_path = certPath
        self._key_path = keyPath
        self._host = host
        self._port = port
        self._timeout_sec = timeoutSec
        # Maps non-200 HTTP status codes to the exception raised for each;
        # any other non-200 code falls back to DiscoveryFailure.
        self._expected_exception_map = {
            self.HTTP_SC_400 : DiscoveryInvalidRequestException(),
            self.HTTP_SC_401 : DiscoveryUnauthorizedException(),
            self.HTTP_SC_404 : DiscoveryDataNotFoundException(),
            self.HTTP_SC_429 : DiscoveryThrottlingException()
        }

    def configureEndpoint(self, host, port=8443):
        """
        **Description**

        Used to configure the host address and port number for the discovery request to hit. Should be called before
        the discovery request happens.

        **Syntax**

        .. code:: python

          # Using default port configuration, 8443
          myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com")
          # Customize port configuration
          myDiscoveryInfoProvider.configureEndpoint(host="prefix.iot.us-east-1.amazonaws.com", port=8888)

        **Parameters**

        *host* - String that denotes the host name of the user-specific AWS IoT endpoint.

        *port* - Integer that denotes the port number to connect to. For discovery purpose, it is 8443 by default.

        **Returns**

        None

        """
        self._host = host
        self._port = port

    def configureCredentials(self, caPath, certPath, keyPath):
        """
        **Description**

        Used to configure the credentials for discovery request. Should be called before the discovery request happens.

        **Syntax**

        .. code:: python

          myDiscoveryInfoProvider.configureCredentials("my/ca/path", "my/cert/path", "my/key/path")

        **Parameters**

        *caPath* - Path to read the root CA file.

        *certPath* - Path to read the certificate file.

        *keyPath* - Path to read the private key file.

        **Returns**

        None

        """
        self._ca_path = caPath
        self._cert_path = certPath
        self._key_path = keyPath

    def configureTimeout(self, timeoutSec):
        """
        **Description**

        Used to configure the time out in seconds for discovery request sending/response waiting. Should be called before
        the discovery request happens.

        **Syntax**

        .. code:: python

          # Configure the time out for discovery to be 10 seconds
          myDiscoveryInfoProvider.configureTimeout(10)

        **Parameters**

        *timeoutSec* - Time out configuration in seconds to consider a discovery request sending/response waiting has
        been timed out.

        **Returns**

        None

        """
        self._timeout_sec = timeoutSec

    def discover(self, thingName):
        """
        **Description**

        Perform the discovery request for the given Greengrass aware device thing name.

        **Syntax**

        .. code:: python

          myDiscoveryInfoProvider.discover(thingName="myGGAD")

        **Parameters**

        *thingName* - Greengrass aware device thing name.

        **Returns**

        :code:`AWSIoTPythonSDK.core.greengrass.discovery.models.DiscoveryInfo` object.

        """
        self._logger.info("Starting discover request...")
        self._logger.info("Endpoint: " + self._host + ":" + str(self._port))
        self._logger.info("Target thing: " + thingName)
        # Pipeline: TCP connect -> TLS wrap -> send GET -> parse response.
        sock = self._create_tcp_connection()
        ssl_sock = self._create_ssl_connection(sock)
        self._raise_on_timeout(self._send_discovery_request(ssl_sock, thingName))
        status_code, response_body = self._receive_discovery_response(ssl_sock)

        return self._raise_if_not_200(status_code, response_body)

    def _create_tcp_connection(self):
        # Open a plain TCP connection to the configured endpoint. The
        # source_address keyword only exists on Python >= 2.7 / >= 3.2,
        # hence the version check.
        self._logger.debug("Creating tcp connection...")
        try:
            if (sys.version_info[0] == 2 and sys.version_info[1] < 7) or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
                sock = socket.create_connection((self._host, self._port))
            else:
                sock = socket.create_connection((self._host, self._port), source_address=("", 0))
            return sock
        except socket.error as err:
            if err.errno != errno.EINPROGRESS and err.errno != errno.EWOULDBLOCK and err.errno != EAGAIN:
                raise
        # NOTE(review): if the connect raises EINPROGRESS/EWOULDBLOCK/EAGAIN,
        # control falls through here and the method implicitly returns None,
        # which _create_ssl_connection would then try to wrap -- confirm
        # whether this path is reachable with blocking create_connection.
        self._logger.debug("Created tcp connection.")

    def _create_ssl_connection(self, sock):
        # Wrap the TCP socket with mutually-authenticated TLS. Port 443 uses
        # an ALPN-enabled context ('x-amzn-http-ca') built via
        # SSLContextBuilder; any other port uses ssl.wrap_socket directly.
        # NOTE(review): ssl.wrap_socket and ssl.match_hostname are deprecated
        # (removed in Python 3.12) -- migration to SSLContext would be needed
        # for newer runtimes.
        self._logger.debug("Creating ssl connection...")

        ssl_protocol_version = ssl.PROTOCOL_SSLv23

        if self._port == 443:
            ssl_context = SSLContextBuilder()\
                .with_ca_certs(self._ca_path)\
                .with_cert_key_pair(self._cert_path, self._key_path)\
                .with_cert_reqs(ssl.CERT_REQUIRED)\
                .with_check_hostname(True)\
                .with_ciphers(None)\
                .with_alpn_protocols(['x-amzn-http-ca'])\
                .build()
            ssl_sock = ssl_context.wrap_socket(sock, server_hostname=self._host, do_handshake_on_connect=False)
            ssl_sock.do_handshake()
        else:
            ssl_sock = ssl.wrap_socket(sock,
                                       certfile=self._cert_path,
                                       keyfile=self._key_path,
                                       ca_certs=self._ca_path,
                                       cert_reqs=ssl.CERT_REQUIRED,
                                       ssl_version=ssl_protocol_version)

        # Verify the server certificate matches the endpoint host name.
        # Older Pythons (< 3.2) lack ssl.match_hostname, so fall back to the
        # hand-rolled matcher below.
        self._logger.debug("Matching host name...")
        if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 2):
            self._tls_match_hostname(ssl_sock)
        else:
            ssl.match_hostname(ssl_sock.getpeercert(), self._host)

        return ssl_sock

    def _tls_match_hostname(self, ssl_sock):
        # Manual host-name verification for Pythons without ssl.match_hostname:
        # check subjectAltName DNS/IP entries first, then fall back to the
        # certificate subject commonName. Raises ssl.SSLError on mismatch.
        try:
            cert = ssl_sock.getpeercert()
        except AttributeError:
            # the getpeercert can throw Attribute error: object has no attribute 'peer_certificate'
            # Don't let that crash the whole client. See also: http://bugs.python.org/issue13721
            raise ssl.SSLError('Not connected')

        san = cert.get('subjectAltName')
        if san:
            have_san_dns = False
            for (key, value) in san:
                if key == 'DNS':
                    have_san_dns = True
                    if self._host_matches_cert(self._host.lower(), value.lower()) == True:
                        return
                if key == 'IP Address':
                    have_san_dns = True
                    if value.lower() == self._host.lower():
                        return

            if have_san_dns:
                # Only check subject if subjectAltName dns not found.
                raise ssl.SSLError('Certificate subject does not match remote hostname.')
        subject = cert.get('subject')
        if subject:
            for ((key, value),) in subject:
                if key == 'commonName':
                    if self._host_matches_cert(self._host.lower(), value.lower()) == True:
                        return

        raise ssl.SSLError('Certificate subject does not match remote hostname.')

    def _host_matches_cert(self, host, cert_host):
        # Compare a host name against a certificate name, supporting a single
        # leading "*." wildcard that matches exactly one label.
        if cert_host[0:2] == "*.":
            if cert_host.count("*") != 1:
                return False

            # Compare everything after the first label on both sides.
            host_match = host.split(".", 1)[1]
            cert_match = cert_host.split(".", 1)[1]
            if host_match == cert_match:
                return True
            else:
                return False
        else:
            if host == cert_host:
                return True
            else:
                return False

    def _send_discovery_request(self, ssl_sock, thing_name):
        # Assemble "GET /greengrass/discover/thing/<name> HTTP/1.1" plus the
        # Host header, then write until the whole request is out or the
        # configured timeout elapses.
        request = self.REQUEST_TYPE_PREFIX + \
            self.PAYLOAD_PREFIX + \
            thing_name + \
            self.PAYLOAD_SUFFIX + \
            self.HOST_PREFIX + \
            self._host + ":" + str(self._port) + \
            self.HOST_SUFFIX
        self._logger.debug("Sending discover request: " + request)

        start_time = time.time()
        desired_length_to_write = len(request)
        actual_length_written = 0
        while True:
            try:
                # NOTE(review): on retry the full request is re-written; if a
                # partial write were ever counted this could duplicate bytes
                # on the wire -- confirm against the underlying SSL write
                # semantics.
                length_written = ssl_sock.write(request.encode("utf-8"))
                actual_length_written += length_written
            except socket.error as err:
                # NOTE(review): socket errors other than WANT_READ/WANT_WRITE
                # are also swallowed here, leaving the loop to spin until the
                # timeout instead of propagating the error.
                if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
                    pass
            if actual_length_written == desired_length_to_write:
                return self.LOW_LEVEL_RC_COMPLETE
            if start_time + self._timeout_sec < time.time():
                return self.LOW_LEVEL_RC_TIMEOUT

    def _receive_discovery_response(self, ssl_sock):
        # Read headers until the terminating blank line, extract the status
        # code and content-length, then read exactly that many body bytes.
        # Returns (status code string, response body string).
        self._logger.debug("Receiving discover response header...")
        rc1, response_header = self._receive_until(ssl_sock, self._got_two_crlfs)
        status_code, body_length = self._handle_discovery_response_header(rc1, response_header.decode("utf-8"))

        self._logger.debug("Receiving discover response body...")
        rc2, response_body = self._receive_until(ssl_sock, self._got_enough_bytes, body_length)
        response_body = self._handle_discovery_response_body(rc2, response_body.decode("utf-8"))

        return status_code, response_body

    def _receive_until(self, ssl_sock, criteria_function, extra_data=None):
        # Read one byte per iteration until criteria_function((bytes_read,
        # buffer, extra_data)) is satisfied or the timeout elapses. Returns
        # (rc, bytearray) where rc is LOW_LEVEL_RC_COMPLETE or
        # LOW_LEVEL_RC_TIMEOUT.
        # NOTE(review): single-byte reads are slow for large payloads, and
        # socket errors other than WANT_READ/WANT_WRITE are silently ignored.
        start_time = time.time()
        response = bytearray()
        number_bytes_read = 0
        while True:  # Python does not have do-while
            try:
                response.append(self._convert_to_int_py3(ssl_sock.read(1)))
                number_bytes_read += 1
            except socket.error as err:
                if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
                    pass

            if criteria_function((number_bytes_read, response, extra_data)):
                return self.LOW_LEVEL_RC_COMPLETE, response
            if start_time + self._timeout_sec < time.time():
                return self.LOW_LEVEL_RC_TIMEOUT, response

    def _convert_to_int_py3(self, input_char):
        # ord() handles both a Py2 str byte and a Py3 length-1 bytes object;
        # anything ord() rejects is passed through unchanged.
        try:
            return ord(input_char)
        except:
            return input_char

    def _got_enough_bytes(self, data):
        # criteria_function: body read is complete once exactly
        # content-length bytes have been read.
        number_bytes_read, response, target_length = data
        return number_bytes_read == int(target_length)

    def _got_two_crlfs(self, data):
        # criteria_function: the header section is complete once the stream
        # ends with two consecutive CRLFs ("\r\n\r\n").
        number_bytes_read, response, extra_data_unused = data
        number_of_crlf = 2
        has_enough_bytes = number_bytes_read > number_of_crlf * 2 - 1
        if has_enough_bytes:
            end_of_received = response[number_bytes_read - number_of_crlf * 2 : number_bytes_read]
            expected_end_of_response = b"\r\n" * number_of_crlf
            return end_of_received == expected_end_of_response
        else:
            return False

    def _handle_discovery_response_header(self, rc, response):
        # Parse the raw header text; returns (status code string,
        # content-length string).
        # NOTE(review): if either regex fails to match a malformed response,
        # .group(1) is called on None and raises AttributeError rather than
        # a discovery exception -- confirm that is acceptable to callers.
        self._raise_on_timeout(rc)
        http_status_code_matcher = re.compile(self.HTTP_RESPONSE_CODE_PATTERN)
        http_status_code_matched_groups = http_status_code_matcher.match(response)
        content_length_matcher = re.compile(self.CONTENT_LENGTH_PATTERN)
        content_length_matched_groups = content_length_matcher.search(response)
        return http_status_code_matched_groups.group(1), content_length_matched_groups.group(1)

    def _handle_discovery_response_body(self, rc, response):
        # Body handling only needs the timeout check; the text is returned
        # unchanged.
        self._raise_on_timeout(rc)
        return response

    def _raise_on_timeout(self, rc):
        # Translate the low-level timeout return code into an exception.
        if rc == self.LOW_LEVEL_RC_TIMEOUT:
            raise DiscoveryTimeoutException()

    def _raise_if_not_200(self, status_code, response_body):  # response_body here is str in Py3
        # Map a non-200 status to its dedicated exception (400/401/404/429)
        # or a generic DiscoveryFailure; on 200, wrap the body in a
        # DiscoveryInfo model object.
        if status_code != self.HTTP_SC_200:
            expected_exception = self._expected_exception_map.get(status_code)
            if expected_exception:
                raise expected_exception
            else:
                raise DiscoveryFailure(response_body)
        return DiscoveryInfo(response_body)
|
||||
@@ -0,0 +1,156 @@
|
||||
# /*
|
||||
# * Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import json
|
||||
|
||||
# Topic building blocks for the AWS IoT Jobs MQTT API
# (topics of the form $aws/things/<thingName>/jobs/...).
_BASE_THINGS_TOPIC = "$aws/things/"
# Operation segments that appear in job topics.
_NOTIFY_OPERATION = "notify"
_NOTIFY_NEXT_OPERATION = "notify-next"
_GET_OPERATION = "get"
_START_NEXT_OPERATION = "start-next"
_WILDCARD_OPERATION = "+"
_UPDATE_OPERATION = "update"
# Reply-topic suffixes appended to a request topic.
_ACCEPTED_REPLY = "accepted"
_REJECTED_REPLY = "rejected"
_WILDCARD_REPLY = "#"

#Members of this enum are tuples
# Tuple indices into the jobExecutionTopicType members below:
# index 1 -> whether the topic requires a job id, index 2 -> the operation
# string segment.
_JOB_ID_REQUIRED_INDEX = 1
_JOB_OPERATION_INDEX = 2

# JSON keys used when serializing jobs request payloads.
_STATUS_KEY = 'status'
_STATUS_DETAILS_KEY = 'statusDetails'
_EXPECTED_VERSION_KEY = 'expectedVersion'
# NOTE(review): constant name has a typo ("EXEXCUTION"); the key value itself
# is correct. Kept as-is since renaming could break other references.
_EXEXCUTION_NUMBER_KEY = 'executionNumber'
_INCLUDE_JOB_EXECUTION_STATE_KEY = 'includeJobExecutionState'
_INCLUDE_JOB_DOCUMENT_KEY = 'includeJobDocument'
_CLIENT_TOKEN_KEY = 'clientToken'
_STEP_TIMEOUT_IN_MINUTES_KEY = 'stepTimeoutInMinutes'
|
||||
#The type of job topic.
|
||||
class jobExecutionTopicType(object):
    """Enumeration of jobs topic types.

    Each member is a tuple of
    (ordinal, requires a job id?, operation topic segment) -- see
    _JOB_ID_REQUIRED_INDEX and _JOB_OPERATION_INDEX.
    """
    JOB_UNRECOGNIZED_TOPIC = (0, False, '')
    JOB_GET_PENDING_TOPIC = (1, False, _GET_OPERATION)
    JOB_START_NEXT_TOPIC = (2, False, _START_NEXT_OPERATION)
    JOB_DESCRIBE_TOPIC = (3, True, _GET_OPERATION)
    JOB_UPDATE_TOPIC = (4, True, _UPDATE_OPERATION)
    JOB_NOTIFY_TOPIC = (5, False, _NOTIFY_OPERATION)
    JOB_NOTIFY_NEXT_TOPIC = (6, False, _NOTIFY_NEXT_OPERATION)
    JOB_WILDCARD_TOPIC = (7, False, _WILDCARD_OPERATION)
|
||||
|
||||
#Members of this enum are tuples
|
||||
# Tuple index into the jobExecutionTopicReplyType members below: index 1 is
# the topic suffix appended for that reply type.
_JOB_SUFFIX_INDEX = 1
#The type of reply topic, or #JOB_REQUEST_TYPE for topics that are not replies.
class jobExecutionTopicReplyType(object):
    # Each member is a tuple of (ordinal, topic suffix string).
    JOB_UNRECOGNIZED_TOPIC_TYPE = (0, '')
    JOB_REQUEST_TYPE = (1, '')
    JOB_ACCEPTED_REPLY_TYPE = (2, '/' + _ACCEPTED_REPLY)
    JOB_REJECTED_REPLY_TYPE = (3, '/' + _REJECTED_REPLY)
    JOB_WILDCARD_REPLY_TYPE = (4, '/' + _WILDCARD_REPLY)
|
||||
|
||||
_JOB_STATUS_INDEX = 1
|
||||
class jobExecutionStatus(object):
|
||||
JOB_EXECUTION_STATUS_NOT_SET = (0, None)
|
||||
JOB_EXECUTION_QUEUED = (1, 'QUEUED')
|
||||
JOB_EXECUTION_IN_PROGRESS = (2, 'IN_PROGRESS')
|
||||
JOB_EXECUTION_FAILED = (3, 'FAILED')
|
||||
JOB_EXECUTION_SUCCEEDED = (4, 'SUCCEEDED')
|
||||
JOB_EXECUTION_CANCELED = (5, 'CANCELED')
|
||||
JOB_EXECUTION_REJECTED = (6, 'REJECTED')
|
||||
JOB_EXECUTION_UNKNOWN_STATUS = (99, None)
|
||||
|
||||
def _getExecutionStatus(jobStatus):
|
||||
try:
|
||||
return jobStatus[_JOB_STATUS_INDEX]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def _isWithoutJobIdTopicType(srcJobExecTopicType):
    """Return True for topic types that must not carry a job id."""
    topic_types_without_job_id = (
        jobExecutionTopicType.JOB_GET_PENDING_TOPIC,
        jobExecutionTopicType.JOB_START_NEXT_TOPIC,
        jobExecutionTopicType.JOB_NOTIFY_TOPIC,
        jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC,
    )
    return srcJobExecTopicType in topic_types_without_job_id
|
||||
|
||||
class thingJobManager:
    """Builds AWS IoT Jobs MQTT topic strings and JSON payloads for one thing.

    Topic and payload formats are wire protocol; the exact strings come from
    the module-level _*_KEY / _*_OPERATION / _BASE_THINGS_TOPIC constants.
    """

    def __init__(self, thingName, clientToken = None):
        # Name of the IoT thing these topics/payloads are scoped to.
        self._thingName = thingName
        # Optional client token included in request payloads when set.
        self._clientToken = clientToken

    def getJobTopic(self, srcJobExecTopicType, srcJobExecTopicReplyType=jobExecutionTopicReplyType.JOB_REQUEST_TYPE, jobId=None):
        """Return the MQTT topic for the given topic/reply kind, or None when the combination is invalid."""
        if self._thingName is None:
            return None

        #Verify topics that only support request type, actually have request type specified for reply
        if (srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_TOPIC or srcJobExecTopicType == jobExecutionTopicType.JOB_NOTIFY_NEXT_TOPIC) and srcJobExecTopicReplyType != jobExecutionTopicReplyType.JOB_REQUEST_TYPE:
            return None

        #Verify topics that explicitly do not want a job ID do not have one specified
        if (jobId is not None and _isWithoutJobIdTopicType(srcJobExecTopicType)):
            return None

        #Verify job ID is present if the topic requires one
        if jobId is None and srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
            return None

        #Ensure the job operation is a non-empty string
        if srcJobExecTopicType[_JOB_OPERATION_INDEX] == '':
            return None

        if srcJobExecTopicType[_JOB_ID_REQUIRED_INDEX]:
            return '{0}{1}/jobs/{2}/{3}{4}'.format(_BASE_THINGS_TOPIC, self._thingName, str(jobId), srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])
        elif srcJobExecTopicType == jobExecutionTopicType.JOB_WILDCARD_TOPIC:
            return '{0}{1}/jobs/#'.format(_BASE_THINGS_TOPIC, self._thingName)
        else:
            return '{0}{1}/jobs/{2}{3}'.format(_BASE_THINGS_TOPIC, self._thingName, srcJobExecTopicType[_JOB_OPERATION_INDEX], srcJobExecTopicReplyType[_JOB_SUFFIX_INDEX])

    def serializeJobExecutionUpdatePayload(self, status, statusDetails=None, expectedVersion=0, executionNumber=0, includeJobExecutionState=False, includeJobDocument=False, stepTimeoutInMinutes=None):
        """Return the JSON payload for an UpdateJobExecution request, or None for an unset/unknown status."""
        executionStatus = _getExecutionStatus(status)
        if executionStatus is None:
            return None
        payload = {_STATUS_KEY: executionStatus}
        if statusDetails:
            payload[_STATUS_DETAILS_KEY] = statusDetails
        # expectedVersion/executionNumber of 0 means "not specified" and is omitted.
        if expectedVersion > 0:
            payload[_EXPECTED_VERSION_KEY] = str(expectedVersion)
        if executionNumber > 0:
            payload[_EXEXCUTION_NUMBER_KEY] = str(executionNumber)
        if includeJobExecutionState:
            payload[_INCLUDE_JOB_EXECUTION_STATE_KEY] = True
        if includeJobDocument:
            payload[_INCLUDE_JOB_DOCUMENT_KEY] = True
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        if stepTimeoutInMinutes is not None:
            payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
        return json.dumps(payload)

    def serializeDescribeJobExecutionPayload(self, executionNumber=0, includeJobDocument=True):
        """Return the JSON payload for a DescribeJobExecution request."""
        payload = {_INCLUDE_JOB_DOCUMENT_KEY: includeJobDocument}
        if executionNumber > 0:
            # NOTE(review): serialized as an int here but as a string in
            # serializeJobExecutionUpdatePayload; inconsistency preserved as-is.
            payload[_EXEXCUTION_NUMBER_KEY] = executionNumber
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        return json.dumps(payload)

    def serializeStartNextPendingJobExecutionPayload(self, statusDetails=None, stepTimeoutInMinutes=None):
        """Return the JSON payload for a StartNextPendingJobExecution request."""
        payload = {}
        if self._clientToken is not None:
            payload[_CLIENT_TOKEN_KEY] = self._clientToken
        if statusDetails is not None:
            payload[_STATUS_DETAILS_KEY] = statusDetails
        if stepTimeoutInMinutes is not None:
            payload[_STEP_TIMEOUT_IN_MINUTES_KEY] = stepTimeoutInMinutes
        return json.dumps(payload)

    def serializeClientTokenPayload(self):
        """Return a JSON payload containing only the client token, or '{}' if none was configured."""
        return json.dumps({_CLIENT_TOKEN_KEY: self._clientToken}) if self._clientToken is not None else '{}'
|
||||
@@ -0,0 +1,63 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
# Import the TLS stack; some minimal platforms ship Python without the ssl
# module, in which case SSLContextBuilder.check_supportability reports it.
try:
    import ssl
except ImportError:  # fix: was a bare except; only an import failure should be swallowed
    ssl = None
|
||||
|
||||
|
||||
class SSLContextBuilder(object):
    """Fluent builder that assembles an ``ssl.SSLContext`` step by step.

    Every ``with_*`` method mutates the underlying context and returns the
    builder so calls can be chained; ``build()`` hands back the context.
    """

    def __init__(self):
        self.check_supportability()
        self._ssl_context = ssl.create_default_context()

    def check_supportability(self):
        """Raise if this platform's ssl module cannot provide what this SDK needs."""
        if ssl is None:
            raise RuntimeError("This platform has no SSL/TLS.")
        if not hasattr(ssl, "SSLContext"):
            raise NotImplementedError("This platform does not support SSLContext. Python 2.7.10+/3.5+ is required.")
        if not hasattr(ssl.SSLContext, "set_alpn_protocols"):
            raise NotImplementedError("This platform does not support ALPN as TLS extensions. Python 2.7.10+/3.5+ is required.")

    def with_ca_certs(self, ca_certs):
        """Load the CA certificate(s) used to verify the peer."""
        self._ssl_context.load_verify_locations(ca_certs)
        return self

    def with_cert_key_pair(self, cert_file, key_file):
        """Load the client certificate and its private key."""
        self._ssl_context.load_cert_chain(cert_file, key_file)
        return self

    def with_cert_reqs(self, cert_reqs):
        """Set the peer-certificate verification mode."""
        self._ssl_context.verify_mode = cert_reqs
        return self

    def with_check_hostname(self, check_hostname):
        """Enable or disable hostname matching during the TLS handshake."""
        self._ssl_context.check_hostname = check_hostname
        return self

    def with_ciphers(self, ciphers):
        """Restrict the cipher suites; a None value keeps the defaults untouched."""
        if ciphers is None:
            # set_ciphers() rejects None, so None means "leave the defaults".
            return self
        self._ssl_context.set_ciphers(ciphers)
        return self

    def with_alpn_protocols(self, alpn_protocols):
        """Advertise the given ALPN protocol names."""
        self._ssl_context.set_alpn_protocols(alpn_protocols)
        return self

    def build(self):
        """Return the configured SSLContext."""
        return self._ssl_context
|
||||
@@ -0,0 +1,699 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
# This class implements the progressive backoff logic for auto-reconnect.
|
||||
# It manages the reconnect wait time for the current reconnect, controlling
|
||||
# when to increase it and when to reset it.
|
||||
|
||||
|
||||
import re
|
||||
import sys
|
||||
import ssl
|
||||
import errno
|
||||
import struct
|
||||
import socket
|
||||
import base64
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import ClientError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssNoKeyInEnvironmentError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import wssHandShakeError
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
|
||||
try:
|
||||
from urllib.parse import quote # Python 3+
|
||||
except ImportError:
|
||||
from urllib import quote
|
||||
# INI config file handling
|
||||
try:
|
||||
from configparser import ConfigParser # Python 3+
|
||||
from configparser import NoOptionError
|
||||
from configparser import NoSectionError
|
||||
except ImportError:
|
||||
from ConfigParser import ConfigParser
|
||||
from ConfigParser import NoOptionError
|
||||
from ConfigParser import NoSectionError
|
||||
|
||||
|
||||
class ProgressiveBackOffCore:
    """Exponential backoff with reset-on-stable-connection for auto-reconnect.

    Each call to backOff() sleeps for the current interval and doubles it
    (capped at the maximum); the interval is reset to the base value once a
    connection has stayed up for the configured minimum time.
    """

    # Logger
    _logger = logging.getLogger(__name__)

    def __init__(self, srcBaseReconnectTimeSecond=1, srcMaximumReconnectTimeSecond=32, srcMinimumConnectTimeSecond=20):
        # Base (first-retry) reconnect wait, in seconds. Default 1.
        self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
        # Upper bound on the reconnect wait, in seconds. Default 32.
        self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
        # How long a connection must stay up to count as stable, in seconds. Default 20.
        self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
        # Wait that will be applied to the next reconnect attempt.
        self._currentBackoffTimeSecond = 1
        # Timer that resets the backoff once the connection proves stable.
        self._resetBackoffTimer = None

    def configTime(self, srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond):
        """Reconfigure backoff timing; raises ValueError on invalid values."""
        anyNegative = min(srcBaseReconnectTimeSecond, srcMaximumReconnectTimeSecond, srcMinimumConnectTimeSecond) < 0
        if anyNegative:
            self._logger.error("init: Negative time configuration detected.")
            raise ValueError("Negative time configuration detected.")
        if srcBaseReconnectTimeSecond >= srcMinimumConnectTimeSecond:
            self._logger.error("init: Min connect time should be bigger than base reconnect time.")
            raise ValueError("Min connect time should be bigger than base reconnect time.")
        self._baseReconnectTimeSecond = srcBaseReconnectTimeSecond
        self._maximumReconnectTimeSecond = srcMaximumReconnectTimeSecond
        self._minimumConnectTimeSecond = srcMinimumConnectTimeSecond
        self._currentBackoffTimeSecond = 1

    def backOff(self):
        """Block for the current backoff interval, then grow it for next time.

        Also cancels any pending stable-connection timer, since being called
        here means a disconnect/reconnect just happened.
        """
        self._logger.debug("backOff: current backoff time is: " + str(self._currentBackoffTimeSecond) + " sec.")
        if self._resetBackoffTimer is not None:
            self._resetBackoffTimer.cancel()
        # Block the reconnect logic for the current interval.
        time.sleep(self._currentBackoffTimeSecond)
        if self._currentBackoffTimeSecond == 0:
            # First-ever attempt: start from the configured base.
            self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
        else:
            # r_cur = min(2 * r_cur, r_max)
            self._currentBackoffTimeSecond = min(self._currentBackoffTimeSecond * 2, self._maximumReconnectTimeSecond)

    def startStableConnectionTimer(self):
        """Arm the timer that resets the backoff once the connection stays up long enough."""
        timer = threading.Timer(self._minimumConnectTimeSecond,
                                self._connectionStableThenResetBackoffTime)
        self._resetBackoffTimer = timer
        timer.start()

    def stopStableConnectionTimer(self):
        """Cancel the stable-connection timer, if one is armed."""
        timer = self._resetBackoffTimer
        if timer is not None:
            timer.cancel()

    def _connectionStableThenResetBackoffTime(self):
        # Timer callback: the connection held for _minimumConnectTimeSecond,
        # so treat it as stable and fall back to the base interval.
        self._logger.debug(
            "stableConnection: Resetting the backoff time to: " + str(self._baseReconnectTimeSecond) + " sec.")
        self._currentBackoffTimeSecond = self._baseReconnectTimeSecond
|
||||
|
||||
|
||||
class SigV4Core:
    """Resolves AWS IAM credentials and produces SigV4-signed websocket URLs.

    Credential lookup order: explicit custom config (setIAMCredentials), then
    environment variables, then the shared AWS credentials file.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._aws_access_key_id = ""
        self._aws_secret_access_key = ""
        self._aws_session_token = ""
        # Same default location the AWS CLI uses.
        self._credentialConfigFilePath = "~/.aws/credentials"

    def setIAMCredentials(self, srcAWSAccessKeyID, srcAWSSecretAccessKey, srcAWSSessionToken):
        """Explicitly configure IAM credentials; these take priority over env vars and files."""
        self._aws_access_key_id = srcAWSAccessKeyID
        self._aws_secret_access_key = srcAWSSecretAccessKey
        self._aws_session_token = srcAWSSessionToken

    def _createAmazonDate(self):
        """Return [YYYYMMDD, YYYYMMDDTHHMMSSZ] for the current UTC time (unicode in Py3.x)."""
        amazonDate = []
        currentTime = datetime.utcnow()
        YMDHMS = currentTime.strftime('%Y%m%dT%H%M%SZ')
        YMD = YMDHMS[0:YMDHMS.index('T')]
        amazonDate.append(YMD)
        amazonDate.append(YMDHMS)
        return amazonDate

    def _sign(self, key, message):
        # HMAC-SHA256 of message under key; returned as a utf-8 byte string in Py3.x
        return hmac.new(key, message.encode('utf-8'), hashlib.sha256).digest()

    def _getSignatureKey(self, key, dateStamp, regionName, serviceName):
        # Standard SigV4 signing-key derivation chain; returned as bytes in Py3.x
        kDate = self._sign(('AWS4' + key).encode('utf-8'), dateStamp)
        kRegion = self._sign(kDate, regionName)
        kService = self._sign(kRegion, serviceName)
        kSigning = self._sign(kService, 'aws4_request')
        return kSigning

    def _checkIAMCredentials(self):
        """Return a credentials dict from the first source that provides them."""
        # Check custom config
        ret = self._checkKeyInCustomConfig()
        # Check environment variables
        if not ret:
            ret = self._checkKeyInEnv()
        # Check files
        if not ret:
            ret = self._checkKeyInFiles()
        # All credentials returned as unicode strings in Py3.x
        return ret

    def _checkKeyInEnv(self):
        """Read credentials from AWS_* environment variables; empty dict when absent."""
        ret = dict()
        self._aws_access_key_id = os.environ.get('AWS_ACCESS_KEY_ID')
        self._aws_secret_access_key = os.environ.get('AWS_SECRET_ACCESS_KEY')
        self._aws_session_token = os.environ.get('AWS_SESSION_TOKEN')
        if self._aws_access_key_id is not None and self._aws_secret_access_key is not None:
            ret["aws_access_key_id"] = self._aws_access_key_id
            ret["aws_secret_access_key"] = self._aws_secret_access_key
            # We do not necessarily need session token...
            if self._aws_session_token is not None:
                ret["aws_session_token"] = self._aws_session_token
        self._logger.debug("IAM credentials from env var.")
        return ret

    def _checkKeyInINIDefault(self, srcConfigParser, sectionName):
        """Read credentials from one section of an INI credentials file."""
        ret = dict()
        # Check aws_access_key_id and aws_secret_access_key
        try:
            ret["aws_access_key_id"] = srcConfigParser.get(sectionName, "aws_access_key_id")
            ret["aws_secret_access_key"] = srcConfigParser.get(sectionName, "aws_secret_access_key")
        except NoOptionError:
            # Fix: logger.warn is a deprecated alias of logger.warning.
            self._logger.warning("Cannot find IAM keyID/secretKey in credential file.")
        # We do not continue searching if we cannot even get IAM id/secret right
        if len(ret) == 2:
            # Check aws_session_token, optional
            try:
                ret["aws_session_token"] = srcConfigParser.get(sectionName, "aws_session_token")
            except NoOptionError:
                self._logger.debug("No AWS Session Token found.")
        return ret

    def _checkKeyInFiles(self):
        """Read credentials from the shared AWS credentials file; empty dict on failure."""
        credentialFile = None
        credentialConfig = None
        ret = dict()
        # Should be compatible with aws cli default credential configuration
        # *NIX/Windows
        try:
            # See if we get the file
            credentialConfig = ConfigParser()
            credentialFilePath = os.path.expanduser(self._credentialConfigFilePath)  # Is it compatible with windows? \/
            credentialConfig.read(credentialFilePath)
            # Now we have the file, start looking for credentials...
            # 'default' section
            ret = self._checkKeyInINIDefault(credentialConfig, "default")
            if not ret:
                # 'DEFAULT' section
                ret = self._checkKeyInINIDefault(credentialConfig, "DEFAULT")
            self._logger.debug("IAM credentials from file.")
        except IOError:
            self._logger.debug("No IAM credential configuration file in " + credentialFilePath)
        except NoSectionError:
            self._logger.error("Cannot find IAM 'default' section.")
        return ret

    def _checkKeyInCustomConfig(self):
        """Return the credentials configured via setIAMCredentials; empty dict when unset."""
        ret = dict()
        if self._aws_access_key_id != "" and self._aws_secret_access_key != "":
            ret["aws_access_key_id"] = self._aws_access_key_id
            ret["aws_secret_access_key"] = self._aws_secret_access_key
            # We do not necessarily need session token...
            if self._aws_session_token != "":
                ret["aws_session_token"] = self._aws_session_token
            self._logger.debug("IAM credentials from custom config.")
        return ret

    def createWebsocketEndpoint(self, host, port, region, method, awsServiceName, path):
        """Return a SigV4 presigned wss:// URL (unicode string in Py3.x).

        Raises wssNoKeyInEnvironmentError when no usable credentials are found.
        """
        # Gather all the facts
        amazonDate = self._createAmazonDate()
        amazonDateSimple = amazonDate[0]  # Unicode in 3.x
        amazonDateComplex = amazonDate[1]  # Unicode in 3.x
        allKeys = self._checkIAMCredentials()  # Unicode in 3.x
        if not self._hasCredentialsNecessaryForWebsocket(allKeys):
            raise wssNoKeyInEnvironmentError()
        else:
            # Because of self._hasCredentialsNecessaryForWebsocket(...), keyID and secretKey should not be None from here
            keyID = allKeys["aws_access_key_id"]
            secretKey = allKeys["aws_secret_access_key"]
            # amazonDateSimple and amazonDateComplex are guaranteed not to be None
            queryParameters = "X-Amz-Algorithm=AWS4-HMAC-SHA256" + \
                "&X-Amz-Credential=" + keyID + "%2F" + amazonDateSimple + "%2F" + region + "%2F" + awsServiceName + "%2Faws4_request" + \
                "&X-Amz-Date=" + amazonDateComplex + \
                "&X-Amz-Expires=86400" + \
                "&X-Amz-SignedHeaders=host"  # Unicode in 3.x
            hashedPayload = hashlib.sha256(str("").encode('utf-8')).hexdigest()  # Unicode in 3.x
            # Create the string to sign
            signedHeaders = "host"
            canonicalHeaders = "host:" + host + "\n"
            canonicalRequest = method + "\n" + path + "\n" + queryParameters + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedPayload  # Unicode in 3.x
            hashedCanonicalRequest = hashlib.sha256(str(canonicalRequest).encode('utf-8')).hexdigest()  # Unicode in 3.x
            stringToSign = "AWS4-HMAC-SHA256\n" + amazonDateComplex + "\n" + amazonDateSimple + "/" + region + "/" + awsServiceName + "/aws4_request\n" + hashedCanonicalRequest  # Unicode in 3.x
            # Sign it
            signingKey = self._getSignatureKey(secretKey, amazonDateSimple, region, awsServiceName)
            signature = hmac.new(signingKey, (stringToSign).encode("utf-8"), hashlib.sha256).hexdigest()
            # generate url
            url = "wss://" + host + ":" + str(port) + path + '?' + queryParameters + "&X-Amz-Signature=" + signature
            # See if we have STS token, if we do, add it
            awsSessionTokenCandidate = allKeys.get("aws_session_token")
            if awsSessionTokenCandidate is not None and len(awsSessionTokenCandidate) != 0:
                aws_session_token = allKeys["aws_session_token"]
                url += "&X-Amz-Security-Token=" + quote(aws_session_token.encode("utf-8"))  # Unicode in 3.x
            self._logger.debug("createWebsocketEndpoint: Websocket URL: " + url)
            return url

    def _hasCredentialsNecessaryForWebsocket(self, allKeys):
        """True when both key id and secret key are present and non-empty (session token optional)."""
        awsAccessKeyIdCandidate = allKeys.get("aws_access_key_id")
        awsSecretAccessKeyCandidate = allKeys.get("aws_secret_access_key")
        # None value is NOT considered as valid entries
        # Fix: the original checked awsAccessKeyIdCandidate twice, so a missing
        # secret key slipped past this guard and crashed on len(None) below.
        validEntries = awsAccessKeyIdCandidate is not None and awsSecretAccessKeyCandidate is not None
        if validEntries:
            # Empty value is NOT considered as valid entries
            validEntries &= (len(awsAccessKeyIdCandidate) != 0 and len(awsSecretAccessKeyCandidate) != 0)
        return validEntries
|
||||
|
||||
|
||||
# This is an internal class that buffers the incoming bytes into an
|
||||
# internal buffer until it gets the full desired length of bytes.
|
||||
# At that time, this bufferedReader will be reset.
|
||||
# *Error handling:
|
||||
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
|
||||
# leave them to the paho _packet_read for further handling (ignored and try
|
||||
# again when data is available.
|
||||
# For other errors, leave them to the paho _packet_read for error reporting.
|
||||
|
||||
|
||||
class _BufferedReader:
    """Accumulates bytes from a non-blocking SSL socket until a full read completes.

    Retryable errors (SSL_ERROR_WANT_READ/WRITE, EAGAIN) propagate to paho's
    _packet_read, which calls read() again with the same length; the partial
    data collected so far is kept across those calls.
    """

    _sslSocket = None
    _internalBuffer = None
    _remainedLength = -1
    _bufferingInProgress = False

    def __init__(self, sslSocket):
        self._sslSocket = sslSocket
        self._internalBuffer = bytearray()
        self._bufferingInProgress = False

    def _reset(self):
        self._internalBuffer = bytearray()
        self._remainedLength = -1
        self._bufferingInProgress = False

    def read(self, numberOfBytesToBeBuffered):
        """Return a bytearray of exactly the requested length, buffering across calls."""
        if not self._bufferingInProgress:
            # Fresh request: remember how much we still owe the caller.
            self._remainedLength = numberOfBytesToBeBuffered
            self._bufferingInProgress = True

        while self._remainedLength > 0:
            # If the data is temporarily unavailable, socket.error is raised and caught by paho.
            chunk = self._sslSocket.read(self._remainedLength)
            if not chunk:
                # Server dropped the connection without closing the socket;
                # raise so the reconnect flow is entered.
                raise socket.error(errno.ECONNABORTED, 0)
            self._internalBuffer.extend(chunk)
            self._remainedLength -= len(chunk)

        # Full length buffered: hand it back (always a bytearray) and reset.
        result = self._internalBuffer
        self._reset()
        return result
|
||||
|
||||
|
||||
# This is the internal class that sends requested data out chunk by chunk according
|
||||
# to the availablity of the socket write operation. If the requested bytes of data
|
||||
# (after encoding) needs to be sent out in separate socket write operations (most
|
||||
# probably be interrupted by the error socket.error (errno = ssl.SSL_ERROR_WANT_WRITE).)
|
||||
# , the write pointer is stored to ensure that the continued bytes will be sent next
|
||||
# time this function gets called.
|
||||
# *Error handling:
|
||||
# For retry errors (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, EAGAIN),
|
||||
# leave them to the paho _packet_read for further handling (ignored and try
|
||||
# again when data is available.
|
||||
# For other errors, leave them to the paho _packet_read for error reporting.
|
||||
|
||||
|
||||
class _BufferedWriter:
    """Sends an encoded wss frame over a non-blocking SSL socket, resuming partial writes.

    If the socket accepts only part of the frame, the write position is kept so
    the remainder goes out on the next call; retryable errors propagate to paho.
    """

    _sslSocket = None
    _internalBuffer = None
    _writingInProgress = False
    _requestedDataLength = -1

    def __init__(self, sslSocket):
        self._sslSocket = sslSocket
        self._internalBuffer = bytearray()
        self._writingInProgress = False
        self._requestedDataLength = -1

    def _reset(self):
        self._internalBuffer = bytearray()
        self._writingInProgress = False
        self._requestedDataLength = -1

    # encodedData must be a fully encoded wss frame (bytearray); payloadLength
    # is the raw MQTT packet length to report back once the frame is fully sent.
    def write(self, encodedData, payloadLength):
        """Write as much of the frame as the socket accepts; return payloadLength when done, else 0."""
        if not self._writingInProgress:
            # New frame: stash it plus the MQTT length to report on completion.
            self._internalBuffer = encodedData
            self._writingInProgress = True
            self._requestedDataLength = payloadLength

        sentBytes = self._sslSocket.write(self._internalBuffer)
        self._internalBuffer = self._internalBuffer[sentBytes:]

        if self._internalBuffer:
            # Partial write: keep paho's packet position unchanged until the
            # whole encoded frame has been transmitted.
            return 0
        completedLength = self._requestedDataLength
        self._reset()
        return completedLength
|
||||
|
||||
|
||||
class SecuredWebSocketCore:
|
||||
# Websocket Constants
|
||||
_OP_CONTINUATION = 0x0
|
||||
_OP_TEXT = 0x1
|
||||
_OP_BINARY = 0x2
|
||||
_OP_CONNECTION_CLOSE = 0x8
|
||||
_OP_PING = 0x9
|
||||
_OP_PONG = 0xa
|
||||
# Websocket Connect Status
|
||||
_WebsocketConnectInit = -1
|
||||
_WebsocketDisconnected = 1
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, socket, hostAddress, portNumber, AWSAccessKeyID="", AWSSecretAccessKey="", AWSSessionToken=""):
    """Sign the endpoint with SigV4 and perform the wss handshake over the given SSL socket.

    Raises ValueError when credentials are missing, the handshake fails, or
    the endpoint is invalid (project exceptions are translated to ValueError).
    """
    self._connectStatus = self._WebsocketConnectInit
    # Handlers
    self._sslSocket = socket
    self._sigV4Handler = self._createSigV4Core()
    self._sigV4Handler.setIAMCredentials(AWSAccessKeyID, AWSSecretAccessKey, AWSSessionToken)
    # Endpoint Info
    self._hostAddress = hostAddress
    self._portNumber = portNumber
    # Section Flags: which parts of the current incoming frame have been parsed.
    self._hasOpByte = False
    self._hasPayloadLengthFirst = False
    self._hasPayloadLengthExtended = False
    self._hasMaskKey = False
    self._hasPayload = False
    # Properties for current websocket frame
    self._isFIN = False
    self._RSVBits = None
    self._opCode = None
    self._needMaskKey = False
    self._payloadLengthBytesLength = 1
    self._payloadLength = 0
    self._maskKey = None
    self._payloadDataBuffer = bytearray()  # Once the whole wss connection is lost, there is no need to keep the buffered payload
    try:
        self._handShake(hostAddress, portNumber)
    except wssNoKeyInEnvironmentError:  # Handle SigV4 signing and websocket handshaking errors
        raise ValueError("No Access Key/KeyID Error")
    except wssHandShakeError:
        raise ValueError("Websocket Handshake Error")
    except ClientError as e:
        raise ValueError(e.message)
    # Now we have a socket with secured websocket...
    self._bufferedReader = _BufferedReader(self._sslSocket)
    self._bufferedWriter = _BufferedWriter(self._sslSocket)
|
||||
|
||||
def _createSigV4Core(self):
    # Small factory method; isolates construction of the SigV4 signer.
    return SigV4Core()
|
||||
|
||||
def _generateMaskKey(self):
    # Fresh 4-byte mask key (client-to-server wss frames must be masked).
    return bytearray(os.urandom(4))
    # os.urandom returns ascii str in 2.x, converted to bytearray
    # os.urandom returns bytes in 3.x, converted to bytearray
|
||||
|
||||
def _reset(self):  # Reset the context for wss frame reception
    """Clear the per-frame parsing state; the buffered payload survives across frames."""
    # Control info
    self._hasOpByte = False
    self._hasPayloadLengthFirst = False
    self._hasPayloadLengthExtended = False
    self._hasMaskKey = False
    self._hasPayload = False
    # Frame Info
    self._isFIN = False
    self._RSVBits = None
    self._opCode = None
    self._needMaskKey = False
    self._payloadLengthBytesLength = 1
    self._payloadLength = 0
    self._maskKey = None
    # Never reset the payloadData since we might have fragmented MQTT data from the previous frame
|
||||
|
||||
def _generateWSSKey(self):
    # Random client nonce for the Sec-WebSocket-Key header, base64-encoded.
    return base64.b64encode(os.urandom(128))  # Bytes
|
||||
|
||||
def _verifyWSSResponse(self, response, clientKey):
    """Validate the server's HTTP 101 upgrade response and its accept key.

    Fix: HTTP header names are case-insensitive, but the original located
    ``sec-websocket-accept`` with ``.index()`` on the original-case bytes,
    raising an uncaught ValueError for canonically capitalized headers. The
    header is now located on a lowercased copy (``.strip()``/``.lower()`` are
    length-preserving, so offsets are valid in the original) and the
    case-sensitive base64 value is extracted from the original bytes; a
    missing header returns False instead of raising.
    """
    strippedResponse = response.strip()
    loweredResponse = strippedResponse.lower()
    # Check if it is a 101 response
    if b"101 switching protocols" not in loweredResponse or b"upgrade: websocket" not in loweredResponse or b"connection: upgrade" not in loweredResponse:
        return False
    # Parse out the sec-websocket-accept value
    headerMarker = b"sec-websocket-accept: "
    markerIndex = loweredResponse.find(headerMarker)
    if markerIndex == -1:
        return False
    WSSAcceptKeyIndex = markerIndex + len(headerMarker)
    rawSecWebSocketAccept = strippedResponse[WSSAcceptKeyIndex:].split(b"\r\n")[0].strip()
    # Verify the WSSAcceptKey
    return self._verifyWSSAcceptKey(rawSecWebSocketAccept, clientKey)
|
||||
|
||||
def _verifyWSSAcceptKey(self, srcAcceptKey, clientKey):
    """Check the server's Sec-WebSocket-Accept against the expected SHA-1/base64 digest."""
    GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    digest = hashlib.sha1(clientKey + GUID).digest()
    expectedAcceptKey = base64.b64encode(digest)  # Bytes
    return srcAcceptKey == expectedAcceptKey
|
||||
|
||||
def _handShake(self, hostAddress, portNumber):
    """Create a SigV4-signed URL and perform the HTTP -> websocket upgrade.

    Raises ClientError for a non-IoT endpoint, wssNoKeyInEnvironmentError when
    signing finds no credentials, and wssHandShakeError when the server's 101
    response fails verification; retryable socket errors propagate to paho.
    """
    CRLF = "\r\n"
    # Only AWS IoT data endpoints are accepted; group(2) captures the region.
    IOT_ENDPOINT_PATTERN = r"^[0-9a-zA-Z]+(\.ats|-ats)?\.iot\.(.*)\.amazonaws\..*"
    matched = re.compile(IOT_ENDPOINT_PATTERN, re.IGNORECASE).match(hostAddress)
    if not matched:
        raise ClientError("Invalid endpoint pattern for wss: %s" % hostAddress)
    region = matched.group(2)
    signedURL = self._sigV4Handler.createWebsocketEndpoint(hostAddress, portNumber, region, "GET", "iotdata", "/mqtt")
    # Now we got a signedURL
    path = signedURL[signedURL.index("/mqtt"):]
    # Assemble HTTP request headers
    Method = "GET " + path + " HTTP/1.1" + CRLF
    Host = "Host: " + hostAddress + CRLF
    Connection = "Connection: " + "Upgrade" + CRLF
    Upgrade = "Upgrade: " + "websocket" + CRLF
    secWebSocketVersion = "Sec-WebSocket-Version: " + "13" + CRLF
    rawSecWebSocketKey = self._generateWSSKey()  # Bytes
    secWebSocketKey = "sec-websocket-key: " + rawSecWebSocketKey.decode('utf-8') + CRLF  # Should be randomly generated...
    secWebSocketProtocol = "Sec-WebSocket-Protocol: " + "mqttv3.1" + CRLF
    secWebSocketExtensions = "Sec-WebSocket-Extensions: " + "permessage-deflate; client_max_window_bits" + CRLF
    # Send the HTTP request
    # Ensure that we are sending bytes, not by any chance unicode string
    handshakeBytes = Method + Host + Connection + Upgrade + secWebSocketVersion + secWebSocketProtocol + secWebSocketExtensions + secWebSocketKey + CRLF
    handshakeBytes = handshakeBytes.encode('utf-8')
    self._sslSocket.write(handshakeBytes)
    # Read it back (Non-blocking socket)
    timeStart = time.time()
    wssHandshakeResponse = bytearray()
    while len(wssHandshakeResponse) == 0:
        try:
            wssHandshakeResponse += self._sslSocket.read(1024)  # Response is always less than 1024 bytes
        except socket.error as err:
            # WANT_READ/WANT_WRITE means "not ready yet": keep polling until
            # the connect timeout, then let paho retry the whole connect.
            if err.errno == ssl.SSL_ERROR_WANT_READ or err.errno == ssl.SSL_ERROR_WANT_WRITE:
                if time.time() - timeStart > self._getTimeoutSec():
                    raise err  # We make sure that reconnect gets retried in Paho upon a wss reconnect response timeout
            else:
                raise err
    # Verify response
    # Now both wssHandshakeResponse and rawSecWebSocketKey are byte strings
    if not self._verifyWSSResponse(wssHandshakeResponse, rawSecWebSocketKey):
        raise wssHandShakeError()
    else:
        pass
|
||||
|
||||
def _getTimeoutSec(self):
    # Handshake-response wait limit, taken from the SDK-wide connect/disconnect default.
    return DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
|
||||
|
||||
# Used to create a single wss frame out of a raw MQTT packet.
def _encodeFrame(self, rawPayload, opCode, masked=1):
    """Build one websocket frame around rawPayload.

    Assumes an MQTT packet always fits inside a single frame, so the FIN bit
    is always set (no RSV bits). When masked, a fresh 4-byte key is appended
    after the length field and the payload is XOR-masked with it.
    Raises ValueError if the payload exceeds the 64-bit length limit.
    """
    frame = bytearray()
    # Op byte: FIN=1, RSV=000, 4-bit opcode.
    frame.append(0x80 | opCode)
    # Payload length field (1, 1+2 or 1+8 bytes) with the mask bit on top.
    maskBit = masked
    payloadLength = len(rawPayload)
    if payloadLength <= 125:
        frame.append((maskBit << 7) | payloadLength)
    elif payloadLength <= 0xffff:
        # 16-bit extended payload length
        frame.append((maskBit << 7) | 126)
        frame.extend(struct.pack("!H", payloadLength))
    elif payloadLength <= 0x7fffffffffffffff:
        # 64-bit extended payload length (most significant bit must be 0)
        frame.append((maskBit << 7) | 127)
        frame.extend(struct.pack("!Q", payloadLength))
    else:
        raise ValueError("Exceeds the maximum number of bytes for a single websocket frame.")
    payloadBytes = bytearray(rawPayload)
    if maskBit == 1:
        # Mask key bytes, then XOR-mask the payload with them.
        maskKey = self._generateMaskKey()
        frame.extend(maskKey)
        for i in range(payloadLength):
            payloadBytes[i] ^= maskKey[i % 4]
    frame.extend(payloadBytes)
    # Return the assembled wss frame
    return frame
|
||||
|
||||
# Used for the wss client to close a wss connection
|
||||
# Create and send a masked wss closing frame
|
||||
def _closeWssConnection(self):
|
||||
# Frames sent from client to server must be masked
|
||||
self._sslSocket.write(self._encodeFrame(b"", self._OP_CONNECTION_CLOSE, masked=1))
|
||||
|
||||
# Used for the wss client to respond to a wss PING from server
|
||||
# Create and send a masked PONG frame
|
||||
def _sendPONG(self):
|
||||
# Frames sent from client to server must be masked
|
||||
self._sslSocket.write(self._encodeFrame(b"", self._OP_PONG, masked=1))
|
||||
|
||||
# Override sslSocket read. Always read from the wss internal payload buffer, which
|
||||
# contains the masked MQTT packet. This read will decode ONE wss frame every time
|
||||
# and load in the payload for MQTT _packet_read. At any time, MQTT _packet_read
|
||||
# should be able to read a complete MQTT packet from the payload (buffered per wss
|
||||
# frame payload). If the MQTT packet is break into separate wss frames, different
|
||||
# chunks will be buffered in separate frames and MQTT _packet_read will not be able
|
||||
# to collect a complete MQTT packet to operate on until the necessary payload is
|
||||
# fully buffered.
|
||||
# If the requested number of bytes are not available, SSL_ERROR_WANT_READ will be
|
||||
# raised to trigger another call of _packet_read when the data is available again.
|
||||
def read(self, numberOfBytes):
|
||||
# Check if we have enough data for paho
|
||||
# _payloadDataBuffer will not be empty ony when the payload of a new wss frame
|
||||
# has been unmasked.
|
||||
if len(self._payloadDataBuffer) >= numberOfBytes:
|
||||
ret = self._payloadDataBuffer[0:numberOfBytes]
|
||||
self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
|
||||
# struct.unpack(fmt, string) # Py2.x
|
||||
# struct.unpack(fmt, buffer) # Py3.x
|
||||
# Here ret is always in bytes (buffer interface)
|
||||
if sys.version_info[0] < 3: # Py2.x
|
||||
ret = str(ret)
|
||||
return ret
|
||||
# Emmm, We don't. Try to buffer from the socket (It's a new wss frame).
|
||||
if not self._hasOpByte: # Check if we need to buffer OpByte
|
||||
opByte = self._bufferedReader.read(1)
|
||||
self._isFIN = (opByte[0] & 0x80) == 0x80
|
||||
self._RSVBits = (opByte[0] & 0x70)
|
||||
self._opCode = (opByte[0] & 0x0f)
|
||||
self._hasOpByte = True # Finished buffering opByte
|
||||
# Check if any of the RSV bits are set, if so, close the connection
|
||||
# since client never sends negotiated extensions
|
||||
if self._RSVBits != 0x0:
|
||||
self._closeWssConnection()
|
||||
self._connectStatus = self._WebsocketDisconnected
|
||||
self._payloadDataBuffer = bytearray()
|
||||
raise socket.error(ssl.SSL_ERROR_WANT_READ, "RSV bits set with NO negotiated extensions.")
|
||||
if not self._hasPayloadLengthFirst: # Check if we need to buffer First Payload Length byte
|
||||
payloadLengthFirst = self._bufferedReader.read(1)
|
||||
self._hasPayloadLengthFirst = True # Finished buffering first byte of payload length
|
||||
self._needMaskKey = (payloadLengthFirst[0] & 0x80) == 0x80
|
||||
payloadLengthFirstByteArray = bytearray()
|
||||
payloadLengthFirstByteArray.extend(payloadLengthFirst)
|
||||
self._payloadLength = (payloadLengthFirstByteArray[0] & 0x7f)
|
||||
|
||||
if self._payloadLength == 126:
|
||||
self._payloadLengthBytesLength = 2
|
||||
self._hasPayloadLengthExtended = False # Force to buffer the extended
|
||||
elif self._payloadLength == 127:
|
||||
self._payloadLengthBytesLength = 8
|
||||
self._hasPayloadLengthExtended = False # Force to buffer the extended
|
||||
else: # _payloadLength <= 125:
|
||||
self._hasPayloadLengthExtended = True # No need to buffer extended payload length
|
||||
if not self._hasPayloadLengthExtended: # Check if we need to buffer Extended Payload Length bytes
|
||||
payloadLengthExtended = self._bufferedReader.read(self._payloadLengthBytesLength)
|
||||
self._hasPayloadLengthExtended = True
|
||||
if sys.version_info[0] < 3:
|
||||
payloadLengthExtended = str(payloadLengthExtended)
|
||||
if self._payloadLengthBytesLength == 2:
|
||||
self._payloadLength = struct.unpack("!H", payloadLengthExtended)[0]
|
||||
else: # _payloadLengthBytesLength == 8
|
||||
self._payloadLength = struct.unpack("!Q", payloadLengthExtended)[0]
|
||||
|
||||
if self._needMaskKey: # Response from server is masked, close the connection
|
||||
self._closeWssConnection()
|
||||
self._connectStatus = self._WebsocketDisconnected
|
||||
self._payloadDataBuffer = bytearray()
|
||||
raise socket.error(ssl.SSL_ERROR_WANT_READ, "Server response masked, closing connection and try again.")
|
||||
|
||||
if not self._hasPayload: # Check if we need to buffer the payload
|
||||
payloadForThisFrame = self._bufferedReader.read(self._payloadLength)
|
||||
self._hasPayload = True
|
||||
# Client side should never received a masked packet from the server side
|
||||
# Unmask it as needed
|
||||
#if self._needMaskKey:
|
||||
# for i in range(0, self._payloadLength):
|
||||
# payloadForThisFrame[i] ^= self._maskKey[i % 4]
|
||||
# Append it to the internal payload buffer
|
||||
self._payloadDataBuffer.extend(payloadForThisFrame)
|
||||
# Now we have the complete wss frame, reset the context
|
||||
# Check to see if it is a wss closing frame
|
||||
if self._opCode == self._OP_CONNECTION_CLOSE:
|
||||
self._connectStatus = self._WebsocketDisconnected
|
||||
self._payloadDataBuffer = bytearray() # Ensure that once the wss closing frame comes, we have nothing to read and start all over again
|
||||
# Check to see if it is a wss PING frame
|
||||
if self._opCode == self._OP_PING:
|
||||
self._sendPONG() # Nothing more to do here, if the transmission of the last wssMQTT packet is not finished, it will continue
|
||||
self._reset()
|
||||
# Check again if we have enough data for paho
|
||||
if len(self._payloadDataBuffer) >= numberOfBytes:
|
||||
ret = self._payloadDataBuffer[0:numberOfBytes]
|
||||
self._payloadDataBuffer = self._payloadDataBuffer[numberOfBytes:]
|
||||
# struct.unpack(fmt, string) # Py2.x
|
||||
# struct.unpack(fmt, buffer) # Py3.x
|
||||
# Here ret is always in bytes (buffer interface)
|
||||
if sys.version_info[0] < 3: # Py2.x
|
||||
ret = str(ret)
|
||||
return ret
|
||||
else: # Fragmented MQTT packets in separate wss frames
|
||||
raise socket.error(ssl.SSL_ERROR_WANT_READ, "Not a complete MQTT packet payload within this wss frame.")
|
||||
|
||||
def write(self, bytesToBeSent):
|
||||
# When there is a disconnection, select will report a TypeError which triggers the reconnect.
|
||||
# In reconnect, Paho will set the socket object (mocked by wss) to None, blocking other ops
|
||||
# before a connection is re-established.
|
||||
# This 'low-level' socket write op should always be able to write to plain socket.
|
||||
# Error reporting is performed by Python socket itself.
|
||||
# Wss closing frame handling is performed in the wss read.
|
||||
return self._bufferedWriter.write(self._encodeFrame(bytesToBeSent, self._OP_BINARY, 1), len(bytesToBeSent))
|
||||
|
||||
def close(self):
|
||||
if self._sslSocket is not None:
|
||||
self._sslSocket.close()
|
||||
self._sslSocket = None
|
||||
|
||||
def getpeercert(self):
|
||||
return self._sslSocket.getpeercert()
|
||||
|
||||
def getSSLSocket(self):
|
||||
if self._connectStatus != self._WebsocketDisconnected:
|
||||
return self._sslSocket
|
||||
else:
|
||||
return None # Leave the sslSocket to Paho to close it. (_ssl.close() -> wssCore.close())
|
||||
@@ -0,0 +1,244 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import ssl
|
||||
import logging
|
||||
from threading import Lock
|
||||
from numbers import Number
|
||||
import AWSIoTPythonSDK.core.protocol.paho.client as mqtt
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
|
||||
|
||||
|
||||
class ClientStatus(object):
    """Life-cycle states of the internal MQTT client."""

    IDLE = 0
    CONNECT = 1
    RESUBSCRIBE = 2
    DRAINING = 3
    STABLE = 4
    USER_DISCONNECT = 5
    ABNORMAL_DISCONNECT = 6


class ClientStatusContainer(object):
    """Holds the current ClientStatus and guards transitions out of a
    user-requested disconnect."""

    def __init__(self):
        self._status = ClientStatus.IDLE

    def get_status(self):
        return self._status

    def set_status(self, status):
        # Once the user asked to disconnect, only a fresh connect attempt may
        # change the status again; every other update is ignored.
        if ClientStatus.USER_DISCONNECT == self._status and ClientStatus.CONNECT != status:
            return
        self._status = status
|
||||
|
||||
class InternalAsyncMqttClient(object):
    """Asynchronous facade over the Paho MQTT client.

    Tracks one ack callback per in-flight operation in a map keyed by message
    id (pub/sub/unsub) or by a fixed string id (CONNACK/DISCONNECT/MESSAGE),
    and invokes/removes those callbacks as the corresponding events are
    dispatched.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, client_id, clean_session, protocol, use_wss):
        self._paho_client = self._create_paho_client(client_id, clean_session, None, protocol, use_wss)
        self._use_wss = use_wss
        self._event_callback_map_lock = Lock()
        self._event_callback_map = dict()

    def _create_paho_client(self, client_id, clean_session, user_data, protocol, use_wss):
        self._logger.debug("Initializing MQTT layer...")
        return mqtt.Client(client_id, clean_session, user_data, protocol, use_wss)

    # TODO: Merge credentials providers configuration into one
    def set_cert_credentials_provider(self, cert_credentials_provider):
        # History issue from Yun SDK where AR9331 embedded Linux only have Python 2.7.3
        # pre-installed. In this version, TLSv1_2 is not even an option.
        # SSLv23 is a work-around which selects the highest TLS version between the client
        # and service. If user installs opensslv1.0.1+, this option will work fine for Mutual
        # Auth.
        # Note that we cannot force TLSv1.2 for Mutual Auth. in Python 2.7.3 and TLS support
        # in Python only starts from Python2.7.
        # See also: https://docs.python.org/2/library/ssl.html#ssl.PROTOCOL_SSLv23
        ca_path = cert_credentials_provider.get_ca_path()  # needed in both branches
        if self._use_wss:
            self._paho_client.tls_set(ca_certs=ca_path, cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)
        else:
            cert_path = cert_credentials_provider.get_cert_path()
            key_path = cert_credentials_provider.get_key_path()
            self._paho_client.tls_set(ca_certs=ca_path, certfile=cert_path, keyfile=key_path,
                                      cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_SSLv23)

    def set_iam_credentials_provider(self, iam_credentials_provider):
        self._paho_client.configIAMCredentials(iam_credentials_provider.get_access_key_id(),
                                               iam_credentials_provider.get_secret_access_key(),
                                               iam_credentials_provider.get_session_token())

    def set_endpoint_provider(self, endpoint_provider):
        self._endpoint_provider = endpoint_provider

    def configure_last_will(self, topic, payload, qos, retain=False):
        self._paho_client.will_set(topic, payload, qos, retain)

    def configure_alpn_protocols(self, alpn_protocols):
        self._paho_client.config_alpn_protocols(alpn_protocols)

    def clear_last_will(self):
        self._paho_client.will_clear()

    def set_username_password(self, username, password=None):
        self._paho_client.username_pw_set(username, password)

    def set_socket_factory(self, socket_factory):
        self._paho_client.socket_factory_set(socket_factory)

    def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
        self._paho_client.setBackoffTiming(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)

    def connect(self, keep_alive_sec, ack_callback=None):
        """Open the MQTT connection and, on success, start the network I/O
        thread. Returns the Paho rc of the connect request."""
        host = self._endpoint_provider.get_host()
        port = self._endpoint_provider.get_port()

        with self._event_callback_map_lock:
            self._logger.debug("Filling in fixed event callbacks: CONNACK, DISCONNECT, MESSAGE")
            self._event_callback_map[FixedEventMids.CONNACK_MID] = self._create_combined_on_connect_callback(ack_callback)
            self._event_callback_map[FixedEventMids.DISCONNECT_MID] = self._create_combined_on_disconnect_callback(None)
            self._event_callback_map[FixedEventMids.MESSAGE_MID] = self._create_converted_on_message_callback()

        rc = self._paho_client.connect(host, port, keep_alive_sec)
        if MQTT_ERR_SUCCESS == rc:
            self.start_background_network_io()

        return rc

    def start_background_network_io(self):
        self._logger.debug("Starting network I/O thread...")
        self._paho_client.loop_start()

    def stop_background_network_io(self):
        self._logger.debug("Stopping network I/O thread...")
        self._paho_client.loop_stop()

    def disconnect(self, ack_callback=None):
        """Request a disconnect; when accepted, register ack_callback for the
        resulting disconnect event. Returns the Paho rc."""
        with self._event_callback_map_lock:
            rc = self._paho_client.disconnect()
            if MQTT_ERR_SUCCESS == rc:
                self._logger.debug("Filling in custom disconnect event callback...")
                combined_on_disconnect_callback = self._create_combined_on_disconnect_callback(ack_callback)
                self._event_callback_map[FixedEventMids.DISCONNECT_MID] = combined_on_disconnect_callback
            return rc

    def _create_combined_on_connect_callback(self, ack_callback):
        # Wrap the user callback so on_online() always fires first
        def combined_on_connect_callback(mid, data):
            self.on_online()
            if ack_callback:
                ack_callback(mid, data)
        return combined_on_connect_callback

    def _create_combined_on_disconnect_callback(self, ack_callback):
        # Wrap the user callback so on_offline() always fires first
        def combined_on_disconnect_callback(mid, data):
            self.on_offline()
            if ack_callback:
                ack_callback(mid, data)
        return combined_on_disconnect_callback

    def _create_converted_on_message_callback(self):
        # Adapt the (mid, data) event signature to on_message(message)
        def converted_on_message_callback(mid, data):
            self.on_message(data)
        return converted_on_message_callback

    # For client online notification
    def on_online(self):
        pass

    # For client offline notification
    def on_offline(self):
        pass

    # For client message reception notification
    def on_message(self, message):
        pass

    def publish(self, topic, payload, qos, retain=False, ack_callback=None):
        with self._event_callback_map_lock:
            rc, mid = self._paho_client.publish(topic, payload, qos, retain)
            # QoS 0 publishes never produce a puback, so no callback is kept
            if MQTT_ERR_SUCCESS == rc and qos > 0 and ack_callback:
                self._logger.debug("Filling in custom puback (QoS>0) event callback...")
                self._event_callback_map[mid] = ack_callback
            return rc, mid

    def subscribe(self, topic, qos, ack_callback=None):
        with self._event_callback_map_lock:
            rc, mid = self._paho_client.subscribe(topic, qos)
            if MQTT_ERR_SUCCESS == rc and ack_callback:
                self._logger.debug("Filling in custom suback event callback...")
                self._event_callback_map[mid] = ack_callback
            return rc, mid

    def unsubscribe(self, topic, ack_callback=None):
        with self._event_callback_map_lock:
            rc, mid = self._paho_client.unsubscribe(topic)
            if MQTT_ERR_SUCCESS == rc and ack_callback:
                self._logger.debug("Filling in custom unsuback event callback...")
                self._event_callback_map[mid] = ack_callback
            return rc, mid

    def register_internal_event_callbacks(self, on_connect, on_disconnect, on_publish, on_subscribe, on_unsubscribe, on_message):
        self._logger.debug("Registering internal event callbacks to MQTT layer...")
        self._paho_client.on_connect = on_connect
        self._paho_client.on_disconnect = on_disconnect
        self._paho_client.on_publish = on_publish
        self._paho_client.on_subscribe = on_subscribe
        self._paho_client.on_unsubscribe = on_unsubscribe
        self._paho_client.on_message = on_message

    def unregister_internal_event_callbacks(self):
        self._logger.debug("Unregistering internal event callbacks from MQTT layer...")
        self._paho_client.on_connect = None
        self._paho_client.on_disconnect = None
        self._paho_client.on_publish = None
        self._paho_client.on_subscribe = None
        self._paho_client.on_unsubscribe = None
        self._paho_client.on_message = None

    def invoke_event_callback(self, mid, data=None):
        with self._event_callback_map_lock:
            event_callback = self._event_callback_map.get(mid)
        # The callback runs outside the lock so it may safely re-enter this
        # client (e.g. publish from within a suback callback).
        if event_callback:
            self._logger.debug("Invoking custom event callback...")
            if data is not None:
                event_callback(mid=mid, data=data)
            else:
                event_callback(mid=mid)
            if isinstance(mid, Number):  # Do NOT remove callbacks for CONNACK/DISCONNECT/MESSAGE
                self._logger.debug("This custom event callback is for pub/sub/unsub, removing it after invocation...")
                with self._event_callback_map_lock:
                    # Fix: pop() instead of del -- the entry may have been
                    # removed concurrently (remove_event_callback /
                    # clean_up_event_callbacks) while the callback was
                    # running, in which case del would raise KeyError.
                    self._event_callback_map.pop(mid, None)

    def remove_event_callback(self, mid):
        with self._event_callback_map_lock:
            if mid in self._event_callback_map:
                self._logger.debug("Removing custom event callback...")
                del self._event_callback_map[mid]

    def clean_up_event_callbacks(self):
        with self._event_callback_map_lock:
            self._event_callback_map.clear()

    def get_event_callback_map(self):
        return self._event_callback_map
||||
@@ -0,0 +1,20 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
# Seconds to wait for a connect/disconnect to complete before timing out.
DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC = 30
# Seconds to wait for publish/subscribe/unsubscribe acknowledgements.
DEFAULT_OPERATION_TIMEOUT_SEC = 5
# Seconds between two replayed requests while draining the offline queue.
DEFAULT_DRAINING_INTERNAL_SEC = 0.5
# Query-string prefix appended for SDK usage metrics reporting.
METRICS_PREFIX = "?SDK=Python&Version="
# ALPN protocol name for MQTT over port 443. NOTE: the "PROTCOLS" spelling
# is kept as-is for backward compatibility with existing importers.
ALPN_PROTCOLS = "x-amzn-mqtt-ca"
@@ -0,0 +1,29 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
class EventTypes(object):
    """Categories of internal events produced by the MQTT layer."""

    CONNACK = 0
    DISCONNECT = 1
    PUBACK = 2
    SUBACK = 3
    UNSUBACK = 4
    MESSAGE = 5
||||
|
||||
|
||||
class FixedEventMids(object):
    """Fixed (non-numeric) message ids for events that have no MQTT mid."""

    CONNACK_MID = "CONNECTED"
    DISCONNECT_MID = "DISCONNECTED"
    MESSAGE_MID = "MESSAGE"
    QUEUED_MID = "QUEUED"
@@ -0,0 +1,87 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import logging
|
||||
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
|
||||
|
||||
|
||||
class AppendResults(object):
    """Result codes returned by OfflineRequestQueue.append."""

    APPEND_FAILURE_QUEUE_FULL = -1
    APPEND_FAILURE_QUEUE_DISABLED = -2
    APPEND_SUCCESS = 0
|
||||
|
||||
class OfflineRequestQueue(list):
    """Bounded FIFO holding requests made while the client is offline.

    Size semantics, controlled by max_size:
      * max_size > 0  -> queue holds at most max_size elements
      * max_size == 0 -> queue is disabled; every append is rejected
      * max_size < 0  -> queue is unbounded
    On overflow, either the newest or the oldest element is dropped,
    depending on the configured drop behavior.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, max_size, drop_behavior=DropBehaviorTypes.DROP_NEWEST):
        """Create the queue.

        Raises TypeError when max_size/drop_behavior are not integers and
        ValueError when drop_behavior is not a supported mode.
        """
        if not isinstance(max_size, int) or not isinstance(drop_behavior, int):
            self._logger.error("init: MaximumSize/DropBehavior must be integer.")
            raise TypeError("MaximumSize/DropBehavior must be integer.")
        if drop_behavior != DropBehaviorTypes.DROP_OLDEST and drop_behavior != DropBehaviorTypes.DROP_NEWEST:
            self._logger.error("init: Drop behavior not supported.")
            raise ValueError("Drop behavior not supported.")

        # Fix: initialize *this* instance. The original called
        # list.__init__([]), which initialized a throwaway list, not self.
        super(OfflineRequestQueue, self).__init__()
        self._drop_behavior = drop_behavior
        self._max_size = max_size

    def _is_enabled(self):
        # A maximum size of zero means the queue is turned off entirely
        return self._max_size != 0

    def _need_drop_messages(self):
        # Messages must be dropped when:
        # 1. The queue is limited and full
        # 2. The queue is disabled
        is_queue_full = len(self) >= self._max_size
        is_queue_limited = self._max_size > 0
        is_queue_disabled = not self._is_enabled()
        return (is_queue_full and is_queue_limited) or is_queue_disabled

    def set_behavior_drop_newest(self):
        self._drop_behavior = DropBehaviorTypes.DROP_NEWEST

    def set_behavior_drop_oldest(self):
        self._drop_behavior = DropBehaviorTypes.DROP_OLDEST

    # Override
    def append(self, data):
        """Append to the queue, honoring the size limit and drop behavior.

        Returns AppendResults.APPEND_SUCCESS on success,
        APPEND_FAILURE_QUEUE_FULL when the queue was full (an element was
        dropped), or APPEND_FAILURE_QUEUE_DISABLED when the queue is
        disabled.
        """
        ret = AppendResults.APPEND_SUCCESS
        if self._is_enabled():
            if self._need_drop_messages():
                if DropBehaviorTypes.DROP_NEWEST == self._drop_behavior:
                    # Reject the incoming element
                    # (warning() -- warn() is a deprecated alias)
                    self._logger.warning("append: Full queue. Drop the newest: " + str(data))
                    ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
                else:
                    # Evict the oldest element to make room for the new one
                    current_oldest = super(OfflineRequestQueue, self).pop(0)
                    self._logger.warning("append: Full queue. Drop the oldest: " + str(current_oldest))
                    super(OfflineRequestQueue, self).append(data)
                    ret = AppendResults.APPEND_FAILURE_QUEUE_FULL
            else:
                self._logger.debug("append: Add new element: " + str(data))
                super(OfflineRequestQueue, self).append(data)
        else:
            self._logger.debug("append: Queue is disabled. Drop the message: " + str(data))
            ret = AppendResults.APPEND_FAILURE_QUEUE_DISABLED
        return ret
||||
@@ -0,0 +1,27 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
class RequestTypes(object):
    """Categories of client requests that can be queued while offline."""

    CONNECT = 0
    DISCONNECT = 1
    PUBLISH = 2
    SUBSCRIBE = 3
    UNSUBSCRIBE = 4
||||
|
||||
class QueueableRequest(object):
    """One offline-queued request: its RequestTypes category plus payload."""

    def __init__(self, type, data):
        self.type = type
        self.data = data  # Can be a tuple of the original call arguments
||||
@@ -0,0 +1,296 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import time
|
||||
import logging
|
||||
from threading import Thread
|
||||
from threading import Event
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import EventTypes
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
|
||||
from AWSIoTPythonSDK.core.protocol.internal.queues import OfflineRequestQueue
|
||||
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import topic_matches_sub
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_DRAINING_INTERNAL_SEC
|
||||
|
||||
|
||||
class EventProducer(object):
    """Translates Paho network callbacks into (mid, type, data) events.

    Each callback pushes one tuple onto the shared event queue and wakes up
    the consumer through the shared condition variable.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, cv, event_queue):
        self._cv = cv
        self._event_queue = event_queue

    def on_connect(self, client, user_data, flags, rc):
        """Queue a CONNACK event carrying the connect result code."""
        self._add_to_queue(FixedEventMids.CONNACK_MID, EventTypes.CONNACK, rc)
        self._logger.debug("Produced [connack] event")

    def on_disconnect(self, client, user_data, rc):
        """Queue a DISCONNECT event carrying the disconnect result code."""
        self._add_to_queue(FixedEventMids.DISCONNECT_MID, EventTypes.DISCONNECT, rc)
        self._logger.debug("Produced [disconnect] event")

    def on_publish(self, client, user_data, mid):
        """Queue a PUBACK event for the given message id."""
        self._add_to_queue(mid, EventTypes.PUBACK, None)
        self._logger.debug("Produced [puback] event")

    def on_subscribe(self, client, user_data, mid, granted_qos):
        """Queue a SUBACK event carrying the granted QoS."""
        self._add_to_queue(mid, EventTypes.SUBACK, granted_qos)
        self._logger.debug("Produced [suback] event")

    def on_unsubscribe(self, client, user_data, mid):
        """Queue an UNSUBACK event for the given message id."""
        self._add_to_queue(mid, EventTypes.UNSUBACK, None)
        self._logger.debug("Produced [unsuback] event")

    def on_message(self, client, user_data, message):
        """Queue a MESSAGE event carrying the inbound message."""
        self._add_to_queue(FixedEventMids.MESSAGE_MID, EventTypes.MESSAGE, message)
        self._logger.debug("Produced [message] event")

    def _add_to_queue(self, mid, event_type, data):
        # Enqueue under the condition variable, then signal the consumer
        with self._cv:
            self._event_queue.put((mid, event_type, data))
            self._cv.notify()
||||
|
||||
|
||||
class EventConsumer(object):
|
||||
|
||||
MAX_DISPATCH_INTERNAL_SEC = 0.01
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, cv, event_queue, internal_async_client,
|
||||
subscription_manager, offline_requests_manager, client_status):
|
||||
self._cv = cv
|
||||
self._event_queue = event_queue
|
||||
self._internal_async_client = internal_async_client
|
||||
self._subscription_manager = subscription_manager
|
||||
self._offline_requests_manager = offline_requests_manager
|
||||
self._client_status = client_status
|
||||
self._is_running = False
|
||||
self._draining_interval_sec = DEFAULT_DRAINING_INTERNAL_SEC
|
||||
self._dispatch_methods = {
|
||||
EventTypes.CONNACK : self._dispatch_connack,
|
||||
EventTypes.DISCONNECT : self._dispatch_disconnect,
|
||||
EventTypes.PUBACK : self._dispatch_puback,
|
||||
EventTypes.SUBACK : self._dispatch_suback,
|
||||
EventTypes.UNSUBACK : self._dispatch_unsuback,
|
||||
EventTypes.MESSAGE : self._dispatch_message
|
||||
}
|
||||
self._offline_request_handlers = {
|
||||
RequestTypes.PUBLISH : self._handle_offline_publish,
|
||||
RequestTypes.SUBSCRIBE : self._handle_offline_subscribe,
|
||||
RequestTypes.UNSUBSCRIBE : self._handle_offline_unsubscribe
|
||||
}
|
||||
self._stopper = Event()
|
||||
|
||||
def update_offline_requests_manager(self, offline_requests_manager):
|
||||
self._offline_requests_manager = offline_requests_manager
|
||||
|
||||
def update_draining_interval_sec(self, draining_interval_sec):
|
||||
self._draining_interval_sec = draining_interval_sec
|
||||
|
||||
def get_draining_interval_sec(self):
|
||||
return self._draining_interval_sec
|
||||
|
||||
def is_running(self):
|
||||
return self._is_running
|
||||
|
||||
def start(self):
|
||||
self._stopper.clear()
|
||||
self._is_running = True
|
||||
dispatch_events = Thread(target=self._dispatch)
|
||||
dispatch_events.daemon = True
|
||||
dispatch_events.start()
|
||||
self._logger.debug("Event consuming thread started")
|
||||
|
||||
def stop(self):
|
||||
if self._is_running:
|
||||
self._is_running = False
|
||||
self._clean_up()
|
||||
self._logger.debug("Event consuming thread stopped")
|
||||
|
||||
def _clean_up(self):
|
||||
self._logger.debug("Cleaning up before stopping event consuming")
|
||||
with self._event_queue.mutex:
|
||||
self._event_queue.queue.clear()
|
||||
self._logger.debug("Event queue cleared")
|
||||
self._internal_async_client.stop_background_network_io()
|
||||
self._logger.debug("Network thread stopped")
|
||||
self._internal_async_client.clean_up_event_callbacks()
|
||||
self._logger.debug("Event callbacks cleared")
|
||||
|
||||
def wait_until_it_stops(self, timeout_sec):
    """Block until the dispatch loop has fully exited; False if the timeout hit first."""
    self._logger.debug("Waiting for event consumer to completely stop")
    stopped = self._stopper.wait(timeout=timeout_sec)
    return stopped
|
||||
|
||||
def is_fully_stopped(self):
    """True once the dispatch loop has signalled its exit."""
    return self._stopper.is_set()
|
||||
|
||||
def _dispatch(self):
    # Dispatch-loop body (runs on the thread started by start()). Sleeps on
    # the shared condition variable when the queue is empty — woken either by
    # the EventProducer or by the periodic timeout, so a stop request is
    # noticed within MAX_DISPATCH_INTERNAL_SEC — otherwise drains the queue.
    while self._is_running:
        with self._cv:
            if self._event_queue.empty():
                self._cv.wait(self.MAX_DISPATCH_INTERNAL_SEC)
            else:
                while not self._event_queue.empty():
                    self._dispatch_one()
    # Signal wait_until_it_stops()/is_fully_stopped() that the loop exited.
    self._stopper.set()
    self._logger.debug("Exiting dispatching loop...")
|
||||
|
||||
def _dispatch_one(self):
    # Pop one (mid, event_type, data) tuple and route it to the matching
    # _dispatch_* handler, then invoke any user ack callback registered
    # for this mid.
    mid, event_type, data = self._event_queue.get()
    if mid:
        self._dispatch_methods[event_type](mid, data)
        self._internal_async_client.invoke_event_callback(mid, data=data)
        # We need to make sure disconnect event gets dispatched and then we stop the consumer
        if self._need_to_stop_dispatching(mid):
            self.stop()
|
||||
|
||||
def _need_to_stop_dispatching(self, mid):
    """True when this event is the disconnect ack of a user- or connect-phase disconnect."""
    if mid != FixedEventMids.DISCONNECT_MID:
        return False
    status = self._client_status.get_status()
    return status in (ClientStatus.USER_DISCONNECT, ClientStatus.CONNECT)
|
||||
|
||||
def _dispatch_connack(self, mid, rc):
    # On (re)connect: if there are recorded subscriptions to restore or queued
    # offline requests, run recovery on a separate thread; otherwise go
    # straight to STABLE. The recovery thread sets STABLE itself when done.
    status = self._client_status.get_status()
    self._logger.debug("Dispatching [connack] event")
    if self._need_recover():
        if ClientStatus.STABLE != status:  # To avoid multiple connack dispatching
            self._logger.debug("Has recovery job")
            clean_up_debt = Thread(target=self._clean_up_debt)
            clean_up_debt.start()
    else:
        self._logger.debug("No need for recovery")
        self._client_status.set_status(ClientStatus.STABLE)
|
||||
|
||||
def _need_recover(self):
|
||||
return self._subscription_manager.list_records() or self._offline_requests_manager.has_more()
|
||||
|
||||
def _clean_up_debt(self):
    # Runs on a dedicated thread after reconnect: restore subscriptions first,
    # then drain queued offline requests, then declare the client STABLE.
    self._handle_resubscribe()
    self._handle_draining()
    self._client_status.set_status(ClientStatus.STABLE)
|
||||
|
||||
def _handle_resubscribe(self):
    # Restore every recorded subscription after a reconnect, aborting early
    # if the user requests a disconnect while we work through the list.
    subscriptions = self._subscription_manager.list_records()
    if subscriptions and not self._has_user_disconnect_request():
        self._logger.debug("Start resubscribing")
        self._client_status.set_status(ClientStatus.RESUBSCRIBE)
        for topic, (qos, message_callback, ack_callback) in subscriptions:
            if self._has_user_disconnect_request():
                self._logger.debug("User disconnect detected")
                break
            self._internal_async_client.subscribe(topic, qos, ack_callback)
|
||||
|
||||
def _handle_draining(self):
    # Replay the requests queued while offline, one every
    # _draining_interval_sec, stopping early on a user disconnect.
    if self._offline_requests_manager.has_more() and not self._has_user_disconnect_request():
        self._logger.debug("Start draining")
        self._client_status.set_status(ClientStatus.DRAINING)
        while self._offline_requests_manager.has_more():
            if self._has_user_disconnect_request():
                self._logger.debug("User disconnect detected")
                break
            offline_request = self._offline_requests_manager.get_next()
            if offline_request:
                # Route by request type to the matching replay handler.
                self._offline_request_handlers[offline_request.type](offline_request)
            time.sleep(self._draining_interval_sec)
|
||||
|
||||
def _has_user_disconnect_request(self):
    """True if the user has requested a disconnect (checked mid-recovery)."""
    return self._client_status.get_status() == ClientStatus.USER_DISCONNECT
|
||||
|
||||
def _dispatch_disconnect(self, mid, rc):
    """Handle a disconnect event.

    A user-requested disconnect (or one during the connect phase) is
    expected and leaves the status alone; any other drop is marked
    ABNORMAL_DISCONNECT so the recovery logic kicks in on reconnect.
    rc (the paho result code) is unused here.
    """
    self._logger.debug("Dispatching [disconnect] event")
    status = self._client_status.get_status()
    # Inverted the original "if expected: pass / else: ..." into a single
    # positive condition — behavior unchanged.
    if status not in (ClientStatus.USER_DISCONNECT, ClientStatus.CONNECT):
        self._client_status.set_status(ClientStatus.ABNORMAL_DISCONNECT)
|
||||
|
||||
# For puback, suback and unsuback, ack callback invocation is handled in dispatch_one
|
||||
# Do nothing in the event dispatching itself
|
||||
def _dispatch_puback(self, mid, rc):
|
||||
self._logger.debug("Dispatching [puback] event")
|
||||
|
||||
def _dispatch_suback(self, mid, rc):
|
||||
self._logger.debug("Dispatching [suback] event")
|
||||
|
||||
def _dispatch_unsuback(self, mid, rc):
|
||||
self._logger.debug("Dispatching [unsuback] event")
|
||||
|
||||
def _dispatch_message(self, mid, message):
    # Fan the incoming publish out to every recorded subscription whose topic
    # filter matches (paho's topic_matches_sub handles +/# wildcards).
    self._logger.debug("Dispatching [message] event")
    subscriptions = self._subscription_manager.list_records()
    if subscriptions:
        for topic, (qos, message_callback, _) in subscriptions:
            if topic_matches_sub(topic, message.topic) and message_callback:
                message_callback(None, None, message)  # message_callback(client, userdata, message)
|
||||
|
||||
def _handle_offline_publish(self, request):
|
||||
topic, payload, qos, retain = request.data
|
||||
self._internal_async_client.publish(topic, payload, qos, retain)
|
||||
self._logger.debug("Processed offline publish request")
|
||||
|
||||
def _handle_offline_subscribe(self, request):
|
||||
topic, qos, message_callback, ack_callback = request.data
|
||||
self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
|
||||
self._internal_async_client.subscribe(topic, qos, ack_callback)
|
||||
self._logger.debug("Processed offline subscribe request")
|
||||
|
||||
def _handle_offline_unsubscribe(self, request):
|
||||
topic, ack_callback = request.data
|
||||
self._subscription_manager.remove_record(topic)
|
||||
self._internal_async_client.unsubscribe(topic, ack_callback)
|
||||
self._logger.debug("Processed offline unsubscribe request")
|
||||
|
||||
|
||||
class SubscriptionManager(object):
    """Tracks the topics the client is subscribed to, with their qos and
    callbacks, so subscriptions can be restored after a reconnect."""

    _logger = logging.getLogger(__name__)

    def __init__(self):
        # topic -> (qos, message_callback, ack_callback)
        self._subscription_map = dict()

    def add_record(self, topic, qos, message_callback, ack_callback):
        """Remember a subscription. message_callback and/or ack_callback may be None."""
        self._logger.debug("Adding a new subscription record: %s qos: %d", topic, qos)
        self._subscription_map[topic] = qos, message_callback, ack_callback

    def remove_record(self, topic):
        """Forget a subscription; unknown topics are ignored with a warning.

        Fixes two issues in the original: Logger.warn is a deprecated alias
        of Logger.warning, and the presence check used get()-truthiness plus
        a second lookup for del — a single pop() does both.
        """
        self._logger.debug("Removing subscription record: %s", topic)
        if self._subscription_map.pop(topic, None) is None:
            self._logger.warning("Removing attempt for non-exist subscription record: %s", topic)

    def list_records(self):
        """Return the current subscriptions as a list of (topic, (qos, message_cb, ack_cb))."""
        return list(self._subscription_map.items())
|
||||
|
||||
|
||||
class OfflineRequestsManager(object):
    """Thin facade over OfflineRequestQueue used while the client is offline."""

    _logger = logging.getLogger(__name__)

    def __init__(self, max_size, drop_behavior):
        self._queue = OfflineRequestQueue(max_size, drop_behavior)

    def has_more(self):
        """True if at least one request is queued."""
        return bool(len(self._queue))

    def add_one(self, request):
        """Append a request; returns the queue's append result code."""
        return self._queue.append(request)

    def get_next(self):
        """Pop and return the oldest queued request, or None when empty."""
        if not self.has_more():
            return None
        return self._queue.pop(0)
|
||||
@@ -0,0 +1,373 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import AWSIoTPythonSDK
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import InternalAsyncMqttClient
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatusContainer
|
||||
from AWSIoTPythonSDK.core.protocol.internal.clients import ClientStatus
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import EventProducer
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import EventConsumer
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import SubscriptionManager
|
||||
from AWSIoTPythonSDK.core.protocol.internal.workers import OfflineRequestsManager
|
||||
from AWSIoTPythonSDK.core.protocol.internal.requests import RequestTypes
|
||||
from AWSIoTPythonSDK.core.protocol.internal.requests import QueueableRequest
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import DEFAULT_OPERATION_TIMEOUT_SEC
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import METRICS_PREFIX
|
||||
from AWSIoTPythonSDK.core.protocol.internal.defaults import ALPN_PROTCOLS
|
||||
from AWSIoTPythonSDK.core.protocol.internal.events import FixedEventMids
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import MQTT_ERR_SUCCESS
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import connectTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import disconnectTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueFullException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import publishQueueDisabledException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueFullException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeQueueDisabledException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueFullException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeQueueDisabledException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import subscribeTimeoutException
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeError
|
||||
from AWSIoTPythonSDK.exception.AWSIoTExceptions import unsubscribeTimeoutException
|
||||
from AWSIoTPythonSDK.core.protocol.internal.queues import AppendResults
|
||||
from AWSIoTPythonSDK.core.util.enums import DropBehaviorTypes
|
||||
from AWSIoTPythonSDK.core.protocol.paho.client import MQTTv31
|
||||
from threading import Condition
|
||||
from threading import Event
|
||||
import logging
|
||||
import sys
|
||||
if sys.version_info[0] < 3:
|
||||
from Queue import Queue
|
||||
else:
|
||||
from queue import Queue
|
||||
|
||||
|
||||
class MqttCore(object):
    """Core MQTT client used by the public SDK wrappers.

    Wraps InternalAsyncMqttClient with sync/async connect, publish,
    subscribe and unsubscribe; queues requests made while offline and
    replays them (plus resubscribes) after a reconnect via the
    EventConsumer recovery logic.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, client_id, clean_session, protocol, use_wss):
        # Authentication / metrics configuration
        self._use_wss = use_wss
        self._username = ""
        self._password = None
        self._enable_metrics_collection = True
        # Event pipeline: producer feeds the queue from paho callbacks;
        # the consumer dispatches queued events on its own thread.
        self._event_queue = Queue()
        self._event_cv = Condition()
        self._event_producer = EventProducer(self._event_cv, self._event_queue)
        self._client_status = ClientStatusContainer()
        self._internal_async_client = InternalAsyncMqttClient(client_id, clean_session, protocol, use_wss)
        self._subscription_manager = SubscriptionManager()
        self._offline_requests_manager = OfflineRequestsManager(-1, DropBehaviorTypes.DROP_NEWEST)  # Infinite queue
        self._event_consumer = EventConsumer(self._event_cv,
                                             self._event_queue,
                                             self._internal_async_client,
                                             self._subscription_manager,
                                             self._offline_requests_manager,
                                             self._client_status)
        self._connect_disconnect_timeout_sec = DEFAULT_CONNECT_DISCONNECT_TIMEOUT_SEC
        self._operation_timeout_sec = DEFAULT_OPERATION_TIMEOUT_SEC
        self._init_offline_request_exceptions()
        self._init_workers()
        self._logger.info("MqttCore initialized")
        self._logger.info("Client id: %s" % client_id)
        self._logger.info("Protocol version: %s" % ("MQTTv3.1" if protocol == MQTTv31 else "MQTTv3.1.1"))
        self._logger.info("Authentication type: %s" % ("SigV4 WebSocket" if use_wss else "TLSv1.2 certificate based Mutual Auth."))

    def _init_offline_request_exceptions(self):
        # Pre-built exception instances raised when an offline request cannot
        # be queued, keyed by request type.
        self._offline_request_queue_disabled_exceptions = {
            RequestTypes.PUBLISH : publishQueueDisabledException(),
            RequestTypes.SUBSCRIBE : subscribeQueueDisabledException(),
            RequestTypes.UNSUBSCRIBE : unsubscribeQueueDisabledException()
        }
        self._offline_request_queue_full_exceptions = {
            RequestTypes.PUBLISH : publishQueueFullException(),
            RequestTypes.SUBSCRIBE : subscribeQueueFullException(),
            RequestTypes.UNSUBSCRIBE : unsubscribeQueueFullException()
        }

    def _init_workers(self):
        # Wire the paho-level events into the EventProducer.
        self._internal_async_client.register_internal_event_callbacks(self._event_producer.on_connect,
                                                                      self._event_producer.on_disconnect,
                                                                      self._event_producer.on_publish,
                                                                      self._event_producer.on_subscribe,
                                                                      self._event_producer.on_unsubscribe,
                                                                      self._event_producer.on_message)

    def _start_workers(self):
        # Only the consumer needs starting; the producer is callback-driven.
        self._event_consumer.start()

    def use_wss(self):
        """True when configured for SigV4 WebSocket authentication."""
        return self._use_wss

    # Used for general message event reception
    def on_message(self, message):
        pass

    # Used for general online event notification
    def on_online(self):
        pass

    # Used for general offline event notification
    def on_offline(self):
        pass

    def configure_cert_credentials(self, cert_credentials_provider):
        """Configure X.509 certificate credentials for mutual-auth TLS."""
        self._logger.info("Configuring certificates...")
        self._internal_async_client.set_cert_credentials_provider(cert_credentials_provider)

    def configure_iam_credentials(self, iam_credentials_provider):
        """Configure AWS IAM credentials for SigV4 WebSocket authentication."""
        self._logger.info("Configuring custom IAM credentials...")
        self._internal_async_client.set_iam_credentials_provider(iam_credentials_provider)

    def configure_endpoint(self, endpoint_provider):
        """Configure the AWS IoT endpoint host/port provider."""
        self._logger.info("Configuring endpoint...")
        self._internal_async_client.set_endpoint_provider(endpoint_provider)

    def configure_connect_disconnect_timeout_sec(self, connect_disconnect_timeout_sec):
        """Set the timeout used by sync connect/disconnect."""
        self._logger.info("Configuring connect/disconnect time out: %f sec" % connect_disconnect_timeout_sec)
        self._connect_disconnect_timeout_sec = connect_disconnect_timeout_sec

    def configure_operation_timeout_sec(self, operation_timeout_sec):
        """Set the timeout used by sync publish/subscribe/unsubscribe."""
        self._logger.info("Configuring MQTT operation time out: %f sec" % operation_timeout_sec)
        self._operation_timeout_sec = operation_timeout_sec

    def configure_reconnect_back_off(self, base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec):
        """Configure progressive reconnect back-off timing."""
        self._logger.info("Configuring reconnect back off timing...")
        self._logger.info("Base quiet time: %f sec" % base_reconnect_quiet_sec)
        self._logger.info("Max quiet time: %f sec" % max_reconnect_quiet_sec)
        self._logger.info("Stable connection time: %f sec" % stable_connection_sec)
        self._internal_async_client.configure_reconnect_back_off(base_reconnect_quiet_sec, max_reconnect_quiet_sec, stable_connection_sec)

    def configure_alpn_protocols(self):
        """Enable the default ALPN protocol list (MQTT over TLS on port 443)."""
        self._logger.info("Configuring alpn protocols...")
        self._internal_async_client.configure_alpn_protocols([ALPN_PROTCOLS])

    def configure_last_will(self, topic, payload, qos, retain=False):
        """Register a last-will message the broker publishes on ungraceful disconnect."""
        self._logger.info("Configuring last will...")
        self._internal_async_client.configure_last_will(topic, payload, qos, retain)

    def clear_last_will(self):
        """Remove any previously configured last-will message."""
        self._logger.info("Clearing last will...")
        self._internal_async_client.clear_last_will()

    def configure_username_password(self, username, password=None):
        """Set the MQTT username/password sent in CONNECT."""
        self._logger.info("Configuring username and password...")
        self._username = username
        self._password = password

    def configure_socket_factory(self, socket_factory):
        """Provide a custom socket factory for the underlying connection."""
        self._logger.info("Configuring socket factory...")
        self._internal_async_client.set_socket_factory(socket_factory)

    def enable_metrics_collection(self):
        """Append the SDK metrics string to the MQTT username (the default)."""
        self._enable_metrics_collection = True

    def disable_metrics_collection(self):
        """Do not append the SDK metrics string to the MQTT username."""
        self._enable_metrics_collection = False

    def configure_offline_requests_queue(self, max_size, drop_behavior):
        """Replace the offline request queue with one of the given size/behavior."""
        self._logger.info("Configuring offline requests queueing: max queue size: %d", max_size)
        self._offline_requests_manager = OfflineRequestsManager(max_size, drop_behavior)
        # Keep the consumer pointing at the new manager.
        self._event_consumer.update_offline_requests_manager(self._offline_requests_manager)

    def configure_draining_interval_sec(self, draining_interval_sec):
        """Set the delay between replayed offline requests after reconnect."""
        self._logger.info("Configuring offline requests queue draining interval: %f sec", draining_interval_sec)
        self._event_consumer.update_draining_interval_sec(draining_interval_sec)

    def connect(self, keep_alive_sec):
        """Sync connect; True on success, raises connectTimeoutException on timeout."""
        self._logger.info("Performing sync connect...")
        event = Event()
        self.connect_async(keep_alive_sec, self._create_blocking_ack_callback(event))
        if not event.wait(self._connect_disconnect_timeout_sec):
            self._logger.error("Connect timed out")
            raise connectTimeoutException()
        return True

    def connect_async(self, keep_alive_sec, ack_callback=None):
        """Async connect; returns the fixed CONNACK mid, raises connectError on failure."""
        self._logger.info("Performing async connect...")
        self._logger.info("Keep-alive: %f sec" % keep_alive_sec)
        self._start_workers()
        self._load_callbacks()
        self._load_username_password()

        try:
            self._client_status.set_status(ClientStatus.CONNECT)
            rc = self._internal_async_client.connect(keep_alive_sec, ack_callback)
            if MQTT_ERR_SUCCESS != rc:
                self._logger.error("Connect error: %d", rc)
                raise connectError(rc)
        except Exception as e:
            # Provided any error in connect, we should clean up the threads that have been created
            self._event_consumer.stop()
            if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
                self._logger.error("Time out in waiting for event consumer to stop")
            else:
                self._logger.debug("Event consumer stopped")
            self._client_status.set_status(ClientStatus.IDLE)
            # NOTE(review): a bare "raise" would preserve the original traceback — confirm before changing.
            raise e

        return FixedEventMids.CONNACK_MID

    def _load_callbacks(self):
        # Forward the subclass-overridable notification hooks to the internal client.
        self._logger.debug("Passing in general notification callbacks to internal client...")
        self._internal_async_client.on_online = self.on_online
        self._internal_async_client.on_offline = self.on_offline
        self._internal_async_client.on_message = self.on_message

    def _load_username_password(self):
        # Optionally append the SDK metrics suffix to the username before connect.
        username_candidate = self._username
        if self._enable_metrics_collection:
            username_candidate += METRICS_PREFIX
            username_candidate += AWSIoTPythonSDK.__version__
        self._internal_async_client.set_username_password(username_candidate, self._password)

    def disconnect(self):
        """Sync disconnect; True on success, raises disconnectTimeoutException on timeout."""
        self._logger.info("Performing sync disconnect...")
        event = Event()
        self.disconnect_async(self._create_blocking_ack_callback(event))
        if not event.wait(self._connect_disconnect_timeout_sec):
            self._logger.error("Disconnect timed out")
            raise disconnectTimeoutException()
        # Also wait for the event consumer thread to wind down completely.
        if not self._event_consumer.wait_until_it_stops(self._connect_disconnect_timeout_sec):
            self._logger.error("Disconnect timed out in waiting for event consumer")
            raise disconnectTimeoutException()
        return True

    def disconnect_async(self, ack_callback=None):
        """Async disconnect; returns the fixed DISCONNECT mid, raises disconnectError on failure."""
        self._logger.info("Performing async disconnect...")
        self._client_status.set_status(ClientStatus.USER_DISCONNECT)
        rc = self._internal_async_client.disconnect(ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Disconnect error: %d", rc)
            raise disconnectError(rc)
        return FixedEventMids.DISCONNECT_MID

    def publish(self, topic, payload, qos, retain=False):
        """Sync publish; queues the request when not STABLE, waits for puback when qos > 0."""
        self._logger.info("Performing sync publish...")
        ret = False
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
        else:
            if qos > 0:
                event = Event()
                rc, mid = self._publish_async(topic, payload, qos, retain, self._create_blocking_ack_callback(event))
                if not event.wait(self._operation_timeout_sec):
                    # Drop the pending callback so a late puback is ignored.
                    self._internal_async_client.remove_event_callback(mid)
                    self._logger.error("Publish timed out")
                    raise publishTimeoutException()
            else:
                # qos 0 has no puback to wait for.
                self._publish_async(topic, payload, qos, retain)
            ret = True
        return ret

    def publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
        """Async publish; returns the mid, or the fixed QUEUED mid when offline."""
        self._logger.info("Performing async publish...")
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.PUBLISH, (topic, payload, qos, retain))
            return FixedEventMids.QUEUED_MID
        else:
            rc, mid = self._publish_async(topic, payload, qos, retain, ack_callback)
            return mid

    def _publish_async(self, topic, payload, qos, retain=False, ack_callback=None):
        # Low-level publish; raises publishError on a non-success paho rc.
        rc, mid = self._internal_async_client.publish(topic, payload, qos, retain, ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Publish error: %d", rc)
            raise publishError(rc)
        return rc, mid

    def subscribe(self, topic, qos, message_callback=None):
        """Sync subscribe; queues the request when not STABLE, waits for suback otherwise."""
        self._logger.info("Performing sync subscribe...")
        ret = False
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, None))
        else:
            event = Event()
            rc, mid = self._subscribe_async(topic, qos, self._create_blocking_ack_callback(event), message_callback)
            if not event.wait(self._operation_timeout_sec):
                self._internal_async_client.remove_event_callback(mid)
                self._logger.error("Subscribe timed out")
                raise subscribeTimeoutException()
            ret = True
        return ret

    def subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
        """Async subscribe; returns the mid, or the fixed QUEUED mid when offline."""
        self._logger.info("Performing async subscribe...")
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.SUBSCRIBE, (topic, qos, message_callback, ack_callback))
            return FixedEventMids.QUEUED_MID
        else:
            rc, mid = self._subscribe_async(topic, qos, ack_callback, message_callback)
            return mid

    def _subscribe_async(self, topic, qos, ack_callback=None, message_callback=None):
        # Record the subscription first so it survives a reconnect, then subscribe.
        self._subscription_manager.add_record(topic, qos, message_callback, ack_callback)
        rc, mid = self._internal_async_client.subscribe(topic, qos, ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Subscribe error: %d", rc)
            raise subscribeError(rc)
        return rc, mid

    def unsubscribe(self, topic):
        """Sync unsubscribe; queues the request when not STABLE, waits for unsuback otherwise."""
        self._logger.info("Performing sync unsubscribe...")
        ret = False
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, None))
        else:
            event = Event()
            rc, mid = self._unsubscribe_async(topic, self._create_blocking_ack_callback(event))
            if not event.wait(self._operation_timeout_sec):
                self._internal_async_client.remove_event_callback(mid)
                self._logger.error("Unsubscribe timed out")
                raise unsubscribeTimeoutException()
            ret = True
        return ret

    def unsubscribe_async(self, topic, ack_callback=None):
        """Async unsubscribe; returns the mid, or the fixed QUEUED mid when offline."""
        self._logger.info("Performing async unsubscribe...")
        if ClientStatus.STABLE != self._client_status.get_status():
            self._handle_offline_request(RequestTypes.UNSUBSCRIBE, (topic, ack_callback))
            return FixedEventMids.QUEUED_MID
        else:
            rc, mid = self._unsubscribe_async(topic, ack_callback)
            return mid

    def _unsubscribe_async(self, topic, ack_callback=None):
        # Drop the stored record first, then unsubscribe on the wire.
        self._subscription_manager.remove_record(topic)
        rc, mid = self._internal_async_client.unsubscribe(topic, ack_callback)
        if MQTT_ERR_SUCCESS != rc:
            self._logger.error("Unsubscribe error: %d", rc)
            raise unsubscribeError(rc)
        return rc, mid

    def _create_blocking_ack_callback(self, event):
        # Returns an ack callback that simply releases the waiting sync caller.
        def ack_callback(mid, data=None):
            event.set()
        return ack_callback

    def _handle_offline_request(self, type, data):
        # Queue a request made while not STABLE; raises the type-specific
        # queue-disabled or queue-full exception when it cannot be queued.
        self._logger.info("Offline request detected!")
        offline_request = QueueableRequest(type, data)
        append_result = self._offline_requests_manager.add_one(offline_request)
        if AppendResults.APPEND_FAILURE_QUEUE_DISABLED == append_result:
            self._logger.error("Offline request queue has been disabled")
            raise self._offline_request_queue_disabled_exceptions[type]
        if AppendResults.APPEND_FAILURE_QUEUE_FULL == append_result:
            self._logger.error("Offline request queue is full")
            raise self._offline_request_queue_full_exceptions[type]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,430 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from threading import Timer, Lock, Thread
|
||||
|
||||
|
||||
class _shadowRequestToken:
|
||||
|
||||
URN_PREFIX_LENGTH = 9
|
||||
|
||||
def getNextToken(self):
|
||||
return uuid.uuid4().urn[self.URN_PREFIX_LENGTH:] # We only need the uuid digits, not the urn prefix
|
||||
|
||||
|
||||
class _basicJSONParser:
|
||||
|
||||
def setString(self, srcString):
|
||||
self._rawString = srcString
|
||||
self._dictionObject = None
|
||||
|
||||
def regenerateString(self):
|
||||
return json.dumps(self._dictionaryObject)
|
||||
|
||||
def getAttributeValue(self, srcAttributeKey):
|
||||
return self._dictionaryObject.get(srcAttributeKey)
|
||||
|
||||
def setAttributeValue(self, srcAttributeKey, srcAttributeValue):
|
||||
self._dictionaryObject[srcAttributeKey] = srcAttributeValue
|
||||
|
||||
def validateJSON(self):
|
||||
try:
|
||||
self._dictionaryObject = json.loads(self._rawString)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class deviceShadow:
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, srcShadowName, srcIsPersistentSubscribe, srcShadowManager):
    """
    A local/client-side device shadow instance.

    Supports Get, Update, Delete, delta listening and delta-listen
    cancellation against the corresponding shadow JSON document in AWS IoT.
    Returned from
    :code:`AWSIoTPythonSDK.MQTTLib.AWSIoTMQTTShadowClient.createShadowWithName`;
    no need to instantiate directly from user scripts.
    """
    if srcShadowName is None or srcIsPersistentSubscribe is None or srcShadowManager is None:
        raise TypeError("None type inputs detected.")
    self._shadowName = srcShadowName
    # Tool handlers
    self._shadowManagerHandler = srcShadowManager
    self._basicJSONParserHandler = _basicJSONParser()
    self._tokenHandler = _shadowRequestToken()
    # Properties
    self._isPersistentSubscribe = srcIsPersistentSubscribe
    self._lastVersionInSync = -1  # -1 means not initialized
    self._isGetSubscribed = False
    self._isUpdateSubscribed = False
    self._isDeleteSubscribed = False
    # Per-action user callbacks and outstanding-request counters
    self._shadowSubscribeCallbackTable = {
        "get": None,
        "delete": None,
        "update": None,
        "delta": None,
    }
    self._shadowSubscribeStatusTable = {"get": 0, "delete": 0, "update": 0}
    self._tokenPool = dict()
    self._dataStructureLock = Lock()
|
||||
|
||||
def _doNonPersistentUnsubscribe(self, currentAction):
|
||||
self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, currentAction)
|
||||
self._logger.info("Unsubscribed to " + currentAction + " accepted/rejected topics for deviceShadow: " + self._shadowName)
|
||||
|
||||
def generalCallback(self, client, userdata, message):
    # Single paho-style entry point for every shadow topic this instance is
    # subscribed to; demultiplexes by action (get/delete/update/delta).
    # In Py3.x, message.payload comes in as a bytes(string)
    # json.loads needs a string input
    with self._dataStructureLock:
        currentTopic = message.topic
        currentAction = self._parseTopicAction(currentTopic)  # get/delete/update/delta
        currentType = self._parseTopicType(currentTopic)  # accepted/rejected/delta
        payloadUTF8String = message.payload.decode('utf-8')
        # get/delete/update: Need to deal with token, timer and unsubscribe
        if currentAction in ["get", "delete", "update"]:
            # Check for token
            self._basicJSONParserHandler.setString(payloadUTF8String)
            if self._basicJSONParserHandler.validateJSON():  # Filter out invalid JSON
                currentToken = self._basicJSONParserHandler.getAttributeValue(u"clientToken")
                if currentToken is not None:
                    self._logger.debug("shadow message clientToken: " + currentToken)
                if currentToken is not None and currentToken in self._tokenPool.keys():  # Filter out JSON without the desired token
                    # Sync local version when it is an accepted response
                    self._logger.debug("Token is in the pool. Type: " + currentType)
                    if currentType == "accepted":
                        incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
                        # If it is get/update accepted response, we need to sync the local version
                        if incomingVersion is not None and incomingVersion > self._lastVersionInSync and currentAction != "delete":
                            self._lastVersionInSync = incomingVersion
                        # If it is a delete accepted, we need to reset the version
                        else:
                            self._lastVersionInSync = -1  # The version will always be synced for the next incoming delta/GU-accepted response
                    # Cancel the timer and clear the token
                    self._tokenPool[currentToken].cancel()
                    del self._tokenPool[currentToken]
                    # Need to unsubscribe?
                    self._shadowSubscribeStatusTable[currentAction] -= 1
                    if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(currentAction) <= 0:
                        self._shadowSubscribeStatusTable[currentAction] = 0
                        # Unsubscribe on a worker thread — we are still holding
                        # _dataStructureLock here.
                        processNonPersistentUnsubscribe = Thread(target=self._doNonPersistentUnsubscribe, args=[currentAction])
                        processNonPersistentUnsubscribe.start()
                    # Custom callback
                    if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
                        processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, currentToken])
                        processCustomCallback.start()
        # delta: Watch for version
        else:
            # Delta responses carry no clientToken; type becomes "delta/<shadowName>".
            currentType += "/" + self._parseTopicShadowName(currentTopic)
            # Sync local version
            self._basicJSONParserHandler.setString(payloadUTF8String)
            if self._basicJSONParserHandler.validateJSON():  # Filter out JSON without version
                incomingVersion = self._basicJSONParserHandler.getAttributeValue(u"version")
                if incomingVersion is not None and incomingVersion > self._lastVersionInSync:
                    self._lastVersionInSync = incomingVersion
            # Custom callback
            if self._shadowSubscribeCallbackTable.get(currentAction) is not None:
                processCustomCallback = Thread(target=self._shadowSubscribeCallbackTable[currentAction], args=[payloadUTF8String, currentType, None])
                processCustomCallback.start()
|
||||
|
||||
def _parseTopicAction(self, srcTopic):
|
||||
ret = None
|
||||
fragments = srcTopic.split('/')
|
||||
if fragments[5] == "delta":
|
||||
ret = "delta"
|
||||
else:
|
||||
ret = fragments[4]
|
||||
return ret
|
||||
|
||||
def _parseTopicType(self, srcTopic):
|
||||
fragments = srcTopic.split('/')
|
||||
return fragments[5]
|
||||
|
||||
def _parseTopicShadowName(self, srcTopic):
|
||||
fragments = srcTopic.split('/')
|
||||
return fragments[2]
|
||||
|
||||
def _timerHandler(self, srcActionName, srcToken):
    """
    Handle expiration of the timeout Timer armed for an in-flight shadow request.

    Removes the token from the pool, unsubscribes from the action's response
    topics when subscriptions are non-persistent and no other requests for
    that action are pending, and notifies the registered callback with a
    "timeout" status.
    """
    with self._dataStructureLock:
        # Don't crash if we try to remove an unknown token (e.g. the
        # response arrived and cleaned it up just before the timer fired)
        if srcToken not in self._tokenPool:
            # Logger.warn is a deprecated alias for warning(); also let the
            # logging framework do the (lazy) formatting.
            self._logger.warning('Tried to remove non-existent token from pool: %s', srcToken)
            return
        # Remove the token
        del self._tokenPool[srcToken]
        # One fewer pending response for this action; unsubscribe if nothing
        # else is waiting and subscriptions are not persistent
        self._shadowSubscribeStatusTable[srcActionName] -= 1
        if not self._isPersistentSubscribe and self._shadowSubscribeStatusTable.get(srcActionName) <= 0:
            self._shadowSubscribeStatusTable[srcActionName] = 0
            self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, srcActionName)
        # Notify the caller that this request timed out
        if self._shadowSubscribeCallbackTable.get(srcActionName) is not None:
            self._logger.info("Shadow request with token: " + str(srcToken) + " has timed out.")
            self._shadowSubscribeCallbackTable[srcActionName]("REQUEST TIME OUT", "timeout", srcToken)
|
||||
|
||||
def shadowGet(self, srcCallback, srcTimeout):
    """
    **Description**

    Request the current device shadow JSON document from AWS IoT. An empty
    JSON document carrying only a client token is published to the shadow
    "get" topic, and the accepted/rejected response topics are subscribed
    to so the result can be delivered to the registered callback. If no
    response arrives within the given timeout, a timeout notification is
    passed to the callback instead.

    **Syntax**

    .. code:: python

      # Retrieve the shadow JSON document from AWS IoT, with a timeout set to 5 seconds
      BotShadow.shadowGet(customCallback, 5)

    **Parameters**

    *srcCallback* - Function of the form :code:`customCallback(payload, responseStatus, token)`
    called when the response for this request comes back; :code:`payload` is the returned
    JSON document, :code:`responseStatus` indicates accepted/rejected/timeout, and
    :code:`token` is the tracing token for this request.

    *srcTimeout* - Seconds to wait before the request is considered timed out and a
    timeout notification is delivered to the callback.

    **Returns**

    The token used for tracing in this shadow request.

    """
    with self._dataStructureLock:
        # Register the response callback and count one more pending response
        self._shadowSubscribeCallbackTable["get"] = srcCallback
        self._shadowSubscribeStatusTable["get"] += 1
        # Allocate a tracing token and arm its timeout timer (started below)
        token = self._tokenHandler.getNextToken()
        self._tokenPool[token] = Timer(srcTimeout, self._timerHandler, ["get", token])
        # The request payload is an empty document holding only the client token
        self._basicJSONParserHandler.setString("{}")
        self._basicJSONParserHandler.validateJSON()
        self._basicJSONParserHandler.setAttributeValue("clientToken", token)
        payload = self._basicJSONParserHandler.regenerateString()
    # Subscribe to the accepted/rejected response topics (once, when persistent)
    if not self._isPersistentSubscribe or not self._isGetSubscribed:
        self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "get", self.generalCallback)
        self._isGetSubscribed = True
        self._logger.info("Subscribed to get accepted/rejected topics for deviceShadow: " + self._shadowName)
    # Publish the request, then start the timeout clock
    self._shadowManagerHandler.basicShadowPublish(self._shadowName, "get", payload)
    self._tokenPool[token].start()
    return token
|
||||
|
||||
def shadowDelete(self, srcCallback, srcTimeout):
    """
    **Description**

    Delete the device shadow from AWS IoT. An empty JSON document carrying
    only a client token is published to the shadow "delete" topic, and the
    accepted/rejected response topics are subscribed to so the outcome can
    be delivered to the registered callback. If no response arrives within
    the given timeout, a timeout notification is passed to the callback
    instead.

    **Syntax**

    .. code:: python

      # Delete the device shadow from AWS IoT, with a timeout set to 5 seconds
      BotShadow.shadowDelete(customCallback, 5)

    **Parameters**

    *srcCallback* - Function of the form :code:`customCallback(payload, responseStatus, token)`
    called when the response for this request comes back; :code:`payload` is the returned
    JSON document, :code:`responseStatus` indicates accepted/rejected/timeout, and
    :code:`token` is the tracing token for this request.

    *srcTimeout* - Seconds to wait before the request is considered timed out and a
    timeout notification is delivered to the callback.

    **Returns**

    The token used for tracing in this shadow request.

    """
    with self._dataStructureLock:
        # Register the response callback and count one more pending response
        self._shadowSubscribeCallbackTable["delete"] = srcCallback
        self._shadowSubscribeStatusTable["delete"] += 1
        # Allocate a tracing token and arm its timeout timer (started below)
        token = self._tokenHandler.getNextToken()
        self._tokenPool[token] = Timer(srcTimeout, self._timerHandler, ["delete", token])
        # The request payload is an empty document holding only the client token
        self._basicJSONParserHandler.setString("{}")
        self._basicJSONParserHandler.validateJSON()
        self._basicJSONParserHandler.setAttributeValue("clientToken", token)
        payload = self._basicJSONParserHandler.regenerateString()
    # Subscribe to the accepted/rejected response topics (once, when persistent)
    if not self._isPersistentSubscribe or not self._isDeleteSubscribed:
        self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delete", self.generalCallback)
        self._isDeleteSubscribed = True
        self._logger.info("Subscribed to delete accepted/rejected topics for deviceShadow: " + self._shadowName)
    # Publish the request, then start the timeout clock
    self._shadowManagerHandler.basicShadowPublish(self._shadowName, "delete", payload)
    self._tokenPool[token].start()
    return token
|
||||
|
||||
def shadowUpdate(self, srcJSONPayload, srcCallback, srcTimeout):
    """
    **Description**

    Update the device shadow JSON document in AWS IoT by publishing the
    provided JSON document to the shadow "update" topic. The
    accepted/rejected response topics are subscribed to so the outcome can
    be delivered to the registered callback. If no response arrives within
    the given timeout, a timeout notification is passed to the callback.

    **Syntax**

    .. code:: python

      # Update the shadow JSON document from AWS IoT, with a timeout set to 5 seconds
      BotShadow.shadowUpdate(newShadowJSONDocumentString, customCallback, 5)

    **Parameters**

    *srcJSONPayload* - JSON document string used to update the shadow JSON document in AWS IoT.

    *srcCallback* - Function of the form :code:`customCallback(payload, responseStatus, token)`
    called when the response for this request comes back; :code:`payload` is the returned
    JSON document, :code:`responseStatus` indicates accepted/rejected/timeout, and
    :code:`token` is the tracing token for this request.

    *srcTimeout* - Seconds to wait before the request is considered timed out and a
    timeout notification is delivered to the callback.

    **Returns**

    The token used for tracing in this shadow request.

    **Raises**

    ValueError if srcJSONPayload is not a valid JSON document.

    """
    with self._dataStructureLock:
        # Validate the payload while holding the lock: _basicJSONParserHandler
        # is shared with shadowGet/shadowDelete, which mutate it under this
        # same lock; validating outside the lock (as before) could corrupt a
        # concurrent request's payload.
        self._basicJSONParserHandler.setString(srcJSONPayload)
        if not self._basicJSONParserHandler.validateJSON():
            raise ValueError("Invalid JSON file.")
        # clientToken: allocate a tracing token and arm its timeout timer
        currentToken = self._tokenHandler.getNextToken()
        self._tokenPool[currentToken] = Timer(srcTimeout, self._timerHandler, ["update", currentToken])
        self._basicJSONParserHandler.setAttributeValue("clientToken", currentToken)
        JSONPayloadWithToken = self._basicJSONParserHandler.regenerateString()
        # Update callback data structure
        self._shadowSubscribeCallbackTable["update"] = srcCallback
        # Update number of pending feedback
        self._shadowSubscribeStatusTable["update"] += 1
    # Subscribe to the accepted/rejected response topics (once, when persistent)
    if not self._isPersistentSubscribe or not self._isUpdateSubscribed:
        self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "update", self.generalCallback)
        self._isUpdateSubscribed = True
        self._logger.info("Subscribed to update accepted/rejected topics for deviceShadow: " + self._shadowName)
    # Publish the tokenized payload, then start the timeout clock
    self._shadowManagerHandler.basicShadowPublish(self._shadowName, "update", JSONPayloadWithToken)
    self._tokenPool[currentToken].start()
    return currentToken
|
||||
|
||||
def shadowRegisterDeltaCallback(self, srcCallback):
    """
    **Description**

    Start listening on this shadow's delta topic. Whenever the desired and
    reported states differ, AWS IoT publishes a delta message which is
    delivered to the registered callback.

    **Syntax**

    .. code:: python

      # Listen on delta topics for BotShadow
      BotShadow.shadowRegisterDeltaCallback(customCallback)

    **Parameters**

    *srcCallback* - Function of the form :code:`customCallback(payload, responseStatus, token)`
    called when a delta message arrives; :code:`payload` is the delta JSON document,
    :code:`responseStatus` identifies the message as a delta, and :code:`token` is
    unused for deltas.

    **Returns**

    None

    """
    with self._dataStructureLock:
        # Remember the delta callback
        self._shadowSubscribeCallbackTable["delta"] = srcCallback
    # Single subscription to the update/delta topic
    self._shadowManagerHandler.basicShadowSubscribe(self._shadowName, "delta", self.generalCallback)
    self._logger.info("Subscribed to delta topic for deviceShadow: " + self._shadowName)
|
||||
|
||||
def shadowUnregisterDeltaCallback(self):
    """
    **Description**

    Stop listening on this shadow's delta topic by unsubscribing from it.
    No further delta messages are delivered after this call, even when the
    desired and reported states differ. Safe to call even if no delta
    callback is currently registered.

    **Syntax**

    .. code:: python

      # Cancel listening on delta topics for BotShadow
      BotShadow.shadowUnregisterDeltaCallback()

    **Parameters**

    None

    **Returns**

    None

    """
    with self._dataStructureLock:
        # pop() with a default keeps this call idempotent; the old
        # `del ...["delta"]` raised KeyError when no callback was registered
        self._shadowSubscribeCallbackTable.pop("delta", None)
    # One unsubscription
    self._shadowManagerHandler.basicShadowUnsubscribe(self._shadowName, "delta")
    self._logger.info("Unsubscribed to delta topics for deviceShadow: " + self._shadowName)
|
||||
@@ -0,0 +1,83 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import logging
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
class _shadowAction:
|
||||
_actionType = ["get", "update", "delete", "delta"]
|
||||
|
||||
def __init__(self, srcShadowName, srcActionName):
|
||||
if srcActionName is None or srcActionName not in self._actionType:
|
||||
raise TypeError("Unsupported shadow action.")
|
||||
self._shadowName = srcShadowName
|
||||
self._actionName = srcActionName
|
||||
self.isDelta = srcActionName == "delta"
|
||||
if self.isDelta:
|
||||
self._topicDelta = "$aws/things/" + str(self._shadowName) + "/shadow/update/delta"
|
||||
else:
|
||||
self._topicGeneral = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName)
|
||||
self._topicAccept = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/accepted"
|
||||
self._topicReject = "$aws/things/" + str(self._shadowName) + "/shadow/" + str(self._actionName) + "/rejected"
|
||||
|
||||
def getTopicGeneral(self):
|
||||
return self._topicGeneral
|
||||
|
||||
def getTopicAccept(self):
|
||||
return self._topicAccept
|
||||
|
||||
def getTopicReject(self):
|
||||
return self._topicReject
|
||||
|
||||
def getTopicDelta(self):
|
||||
return self._topicDelta
|
||||
|
||||
|
||||
class shadowManager:
    """Thin coordinator that maps shadow actions onto mqttCore publish/subscribe calls."""

    _logger = logging.getLogger(__name__)

    def __init__(self, srcMQTTCore):
        """Keep a handle to the mqttCore used for all shadow traffic.

        Raises TypeError if srcMQTTCore is None.
        """
        if srcMQTTCore is None:
            raise TypeError("None type inputs detected.")
        self._mqttCoreHandler = srcMQTTCore
        # Serializes all shadow subscribe/unsubscribe operations
        self._shadowSubUnsubOperationLock = Lock()

    def basicShadowPublish(self, srcShadowName, srcShadowAction, srcPayload):
        """Publish srcPayload (QoS 0, not retained) to the action's request topic."""
        action = _shadowAction(srcShadowName, srcShadowAction)
        self._mqttCoreHandler.publish(action.getTopicGeneral(), srcPayload, 0, False)

    def basicShadowSubscribe(self, srcShadowName, srcShadowAction, srcCallback):
        """Subscribe (QoS 0) to the response topic(s) of the given shadow action."""
        with self._shadowSubUnsubOperationLock:
            action = _shadowAction(srcShadowName, srcShadowAction)
            if action.isDelta:
                self._mqttCoreHandler.subscribe(action.getTopicDelta(), 0, srcCallback)
            else:
                self._mqttCoreHandler.subscribe(action.getTopicAccept(), 0, srcCallback)
                self._mqttCoreHandler.subscribe(action.getTopicReject(), 0, srcCallback)
            # Fixed pause before returning, so the subscriptions are in place
            # before the caller publishes the corresponding request
            # (preserved original SDK behavior).
            time.sleep(2)

    def basicShadowUnsubscribe(self, srcShadowName, srcShadowAction):
        """Unsubscribe from the response topic(s) of the given shadow action."""
        with self._shadowSubUnsubOperationLock:
            action = _shadowAction(srcShadowName, srcShadowAction)
            if action.isDelta:
                self._mqttCoreHandler.unsubscribe(action.getTopicDelta())
            else:
                self._logger.debug(action.getTopicAccept())
                self._mqttCoreHandler.unsubscribe(action.getTopicAccept())
                self._logger.debug(action.getTopicReject())
                self._mqttCoreHandler.unsubscribe(action.getTopicReject())
|
||||
@@ -0,0 +1,19 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class DropBehaviorTypes(object):
    """Enumeration of drop behaviors for a bounded internal request queue."""
    # Drop the oldest queued entry when the queue is full
    DROP_OLDEST = 0
    # Drop the newly arriving entry when the queue is full
    DROP_NEWEST = 1
|
||||
@@ -0,0 +1,92 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class CredentialsProvider(object):
    """Base credentials holder; stores the root CA certificate file path."""

    def __init__(self):
        # Root CA certificate path; empty string until configured
        self._ca_path = ""

    def set_ca_path(self, ca_path):
        """Set the root CA certificate file path."""
        self._ca_path = ca_path

    def get_ca_path(self):
        """Return the configured root CA certificate file path."""
        return self._ca_path
|
||||
|
||||
|
||||
class CertificateCredentialsProvider(CredentialsProvider):
    """Credentials holder for X.509 certificate based authentication."""

    def __init__(self):
        super(CertificateCredentialsProvider, self).__init__()
        # Client certificate and private key paths; empty until configured
        self._cert_path = ""
        self._key_path = ""

    def set_cert_path(self, cert_path):
        """Set the client certificate file path."""
        self._cert_path = cert_path

    def set_key_path(self, key_path):
        """Set the private key file path."""
        self._key_path = key_path

    def get_cert_path(self):
        """Return the configured client certificate file path."""
        return self._cert_path

    def get_key_path(self):
        """Return the configured private key file path."""
        return self._key_path
|
||||
|
||||
|
||||
class IAMCredentialsProvider(CredentialsProvider):
    """Credentials holder for AWS IAM (SigV4/WebSocket) authentication."""

    def __init__(self):
        super(IAMCredentialsProvider, self).__init__()
        # IAM credential triple; empty strings until configured
        self._aws_access_key_id = ""
        self._aws_secret_access_key = ""
        self._aws_session_token = ""

    def set_access_key_id(self, access_key_id):
        """Set the AWS access key id."""
        self._aws_access_key_id = access_key_id

    def set_secret_access_key(self, secret_access_key):
        """Set the AWS secret access key."""
        self._aws_secret_access_key = secret_access_key

    def set_session_token(self, session_token):
        """Set the AWS session token (for temporary credentials)."""
        self._aws_session_token = session_token

    def get_access_key_id(self):
        """Return the configured AWS access key id."""
        return self._aws_access_key_id

    def get_secret_access_key(self):
        """Return the configured AWS secret access key."""
        return self._aws_secret_access_key

    def get_session_token(self):
        """Return the configured AWS session token."""
        return self._aws_session_token
|
||||
|
||||
|
||||
class EndpointProvider(object):
    """Holder for a connection endpoint: host name plus port number."""

    def __init__(self):
        # Host is empty and port is -1 (unset sentinel) until configured
        self._host = ""
        self._port = -1

    def set_host(self, host):
        """Set the endpoint host name."""
        self._host = host

    def set_port(self, port):
        """Set the endpoint port number."""
        self._port = port

    def get_host(self):
        """Return the configured endpoint host name."""
        return self._host

    def get_port(self):
        """Return the configured endpoint port number."""
        return self._port
|
||||
@@ -0,0 +1,153 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
import AWSIoTPythonSDK.exception.operationTimeoutException as operationTimeoutException
|
||||
import AWSIoTPythonSDK.exception.operationError as operationError
|
||||
|
||||
|
||||
# Serial Exception
|
||||
class acceptTimeoutException(Exception):
    """Serial exception: raised when an accept operation times out."""
    def __init__(self, msg="Accept Timeout"):
        self.message = msg
|
||||
|
||||
|
||||
# MQTT Operation Timeout Exception
|
||||
class connectTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT connect does not complete within the configured timeout."""
    def __init__(self, msg="Connect Timeout"):
        self.message = msg
|
||||
|
||||
|
||||
class disconnectTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT disconnect does not complete within the configured timeout."""
    def __init__(self, msg="Disconnect Timeout"):
        self.message = msg
|
||||
|
||||
|
||||
class publishTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT publish does not complete within the configured timeout."""
    def __init__(self, msg="Publish Timeout"):
        self.message = msg
|
||||
|
||||
|
||||
class subscribeTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT subscribe does not complete within the configured timeout."""
    def __init__(self, msg="Subscribe Timeout"):
        self.message = msg
|
||||
|
||||
|
||||
class unsubscribeTimeoutException(operationTimeoutException.operationTimeoutException):
    """Raised when an MQTT unsubscribe does not complete within the configured timeout."""
    def __init__(self, msg="Unsubscribe Timeout"):
        self.message = msg
|
||||
|
||||
|
||||
# MQTT Operation Error
|
||||
class connectError(operationError.operationError):
    """Raised when an MQTT connect fails; carries the failure code in the message."""
    def __init__(self, errorCode):
        self.message = "Connect Error: " + str(errorCode)
|
||||
|
||||
|
||||
class disconnectError(operationError.operationError):
    """Raised when an MQTT disconnect fails; carries the failure code in the message."""
    def __init__(self, errorCode):
        self.message = "Disconnect Error: " + str(errorCode)
|
||||
|
||||
|
||||
class publishError(operationError.operationError):
    """Raised when an MQTT publish fails; carries the failure code in the message."""
    def __init__(self, errorCode):
        self.message = "Publish Error: " + str(errorCode)
|
||||
|
||||
|
||||
class publishQueueFullException(operationError.operationError):
    """Raised when the internal offline publish queue is full."""
    def __init__(self):
        self.message = "Internal Publish Queue Full"
|
||||
|
||||
|
||||
class publishQueueDisabledException(operationError.operationError):
    """Raised when an offline publish is dropped because queueing is disabled."""
    def __init__(self):
        self.message = "Offline publish request dropped because queueing is disabled"
|
||||
|
||||
|
||||
class subscribeError(operationError.operationError):
    """Raised when an MQTT subscribe fails; carries the failure code in the message."""
    def __init__(self, errorCode):
        self.message = "Subscribe Error: " + str(errorCode)
|
||||
|
||||
|
||||
class subscribeQueueFullException(operationError.operationError):
    """Raised when the internal offline subscribe queue is full."""
    def __init__(self):
        self.message = "Internal Subscribe Queue Full"
|
||||
|
||||
|
||||
class subscribeQueueDisabledException(operationError.operationError):
    """Raised when an offline subscribe is dropped because queueing is disabled."""
    def __init__(self):
        self.message = "Offline subscribe request dropped because queueing is disabled"
|
||||
|
||||
|
||||
class unsubscribeError(operationError.operationError):
    """Raised when an MQTT unsubscribe fails; carries the failure code in the message."""
    def __init__(self, errorCode):
        self.message = "Unsubscribe Error: " + str(errorCode)
|
||||
|
||||
|
||||
class unsubscribeQueueFullException(operationError.operationError):
    """Raised when the internal offline unsubscribe queue is full."""
    def __init__(self):
        self.message = "Internal Unsubscribe Queue Full"
|
||||
|
||||
|
||||
class unsubscribeQueueDisabledException(operationError.operationError):
    """Raised when an offline unsubscribe is dropped because queueing is disabled."""
    def __init__(self):
        self.message = "Offline unsubscribe request dropped because queueing is disabled"
|
||||
|
||||
|
||||
# Websocket Error
|
||||
class wssNoKeyInEnvironmentError(operationError.operationError):
    """WebSocket error: AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY not found in the environment."""
    def __init__(self):
        self.message = "No AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY detected in $ENV."
|
||||
|
||||
|
||||
class wssHandShakeError(operationError.operationError):
    """WebSocket error: the WSS handshake with the endpoint failed."""
    def __init__(self):
        self.message = "Error in WSS handshake."
|
||||
|
||||
|
||||
# Greengrass Discovery Error
|
||||
class DiscoveryDataNotFoundException(operationError.operationError):
    """Greengrass discovery error: the request returned no discovery data."""
    def __init__(self):
        self.message = "No discovery data found"
|
||||
|
||||
|
||||
class DiscoveryTimeoutException(operationTimeoutException.operationTimeoutException):
    """Greengrass discovery error: the discovery request timed out."""
    def __init__(self, message="Discovery request timed out"):
        self.message = message
|
||||
|
||||
|
||||
class DiscoveryInvalidRequestException(operationError.operationError):
    """Greengrass discovery error: the discovery request was malformed/invalid."""
    def __init__(self):
        self.message = "Invalid discovery request"
|
||||
|
||||
|
||||
class DiscoveryUnauthorizedException(operationError.operationError):
    """Greengrass discovery error: the discovery request was not authorized."""
    def __init__(self):
        self.message = "Discovery request not authorized"
|
||||
|
||||
|
||||
class DiscoveryThrottlingException(operationError.operationError):
    """Greengrass discovery error: too many discovery requests (throttled)."""
    def __init__(self):
        self.message = "Too many discovery requests"
|
||||
|
||||
|
||||
class DiscoveryFailure(operationError.operationError):
    """Generic Greengrass discovery failure carrying a caller-supplied message."""
    def __init__(self, message):
        self.message = message
|
||||
|
||||
|
||||
# Client Error
|
||||
class ClientError(Exception):
    """Client-side error carrying a caller-supplied message."""
    def __init__(self, message):
        self.message = message
|
||||
@@ -0,0 +1,19 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class operationError(Exception):
    """Base class for operation errors raised by this SDK."""
    def __init__(self, msg="Operation Error"):
        self.message = msg
|
||||
@@ -0,0 +1,19 @@
|
||||
# /*
|
||||
# * Copyright 2010-2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
# *
|
||||
# * Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# * You may not use this file except in compliance with the License.
|
||||
# * A copy of the License is located at
|
||||
# *
|
||||
# * http://aws.amazon.com/apache2.0
|
||||
# *
|
||||
# * or in the "license" file accompanying this file. This file is distributed
|
||||
# * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
# * express or implied. See the License for the specific language governing
|
||||
# * permissions and limitations under the License.
|
||||
# */
|
||||
|
||||
|
||||
class operationTimeoutException(Exception):
    """Base class for operation timeout errors raised by this SDK."""
    def __init__(self, msg="Operation Timeout"):
        self.message = msg
|
||||
2
aws-iot-device-sdk-python/setup.cfg
Normal file
2
aws-iot-device-sdk-python/setup.cfg
Normal file
@@ -0,0 +1,2 @@
|
||||
[metadata]
|
||||
description-file = README.rst
|
||||
34
aws-iot-device-sdk-python/setup.py
Normal file
34
aws-iot-device-sdk-python/setup.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Make the in-tree package importable so its version can be read without installing it
import sys
sys.path.insert(0, 'AWSIoTPythonSDK')
import AWSIoTPythonSDK
# Single source of truth for the release version (AWSIoTPythonSDK/__init__.py)
currentVersion = AWSIoTPythonSDK.__version__

# NOTE(review): distutils is deprecated and removed in Python 3.12; consider
# migrating to setuptools when the supported interpreter range allows it.
from distutils.core import setup
setup(
    name = 'AWSIoTPythonSDK',
    packages=['AWSIoTPythonSDK', 'AWSIoTPythonSDK.core',
              'AWSIoTPythonSDK.core.util', 'AWSIoTPythonSDK.core.shadow', 'AWSIoTPythonSDK.core.protocol',
              'AWSIoTPythonSDK.core.jobs',
              'AWSIoTPythonSDK.core.protocol.paho', 'AWSIoTPythonSDK.core.protocol.internal',
              'AWSIoTPythonSDK.core.protocol.connection', 'AWSIoTPythonSDK.core.greengrass',
              'AWSIoTPythonSDK.core.greengrass.discovery', 'AWSIoTPythonSDK.exception'],
    version = currentVersion,
    description = 'SDK for connecting to AWS IoT using Python.',
    author = 'Amazon Web Service',
    author_email = '',
    url = 'https://github.com/aws/aws-iot-device-sdk-python.git',
    download_url = 'https://s3.amazonaws.com/aws-iot-device-sdk-python/aws-iot-device-sdk-python-latest.zip',
    keywords = ['aws', 'iot', 'mqtt'],
    classifiers = [
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Natural Language :: English",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python",
        "Programming Language :: Python :: 2.7",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.3",
        "Programming Language :: Python :: 3.4",
        "Programming Language :: Python :: 3.5"
    ]
)
|
||||
Reference in New Issue
Block a user