Files
DataLogger-Generic/web_db/mysql-connector-python-2.1.4/tests/test_network.py
2016-11-01 16:56:22 -05:00

428 lines
15 KiB
Python

# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2012, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.network
"""
import os
import socket
import logging
from collections import deque
import unittest
import tests
from mysql.connector import (network, errors, constants)
LOGGER = logging.getLogger(tests.LOGGER_NAME)
class NetworkTests(tests.MySQLConnectorTests):

    """Tests for the module-level helpers of mysql.connector.network"""

    def test__prepare_packets(self):
        """Split a payload into length-prefixed, numbered packets"""
        # A payload which fits into one packet: the expected header is the
        # 3-byte length (0x0e == 14) followed by the packet number.
        payload, pktnr = b'abcdefghijklmn', 1
        expected = [b'\x0e\x00\x00\x01abcdefghijklmn']
        self.assertEqual(expected, network._prepare_packets(payload, pktnr))

        # A payload exceeding MAX_PACKET_LENGTH by 1000 bytes is expected to
        # be spread over two packets, the second carrying the next number.
        overflow = 1000
        payload = b'a' * (constants.MAX_PACKET_LENGTH + overflow)
        pktnr = 2
        expected = [
            b'\xff\xff\xff\x02' + (b'a' * constants.MAX_PACKET_LENGTH),
            b'\xe8\x03\x00\x03' + (b'a' * overflow),
        ]
        self.assertEqual(expected, network._prepare_packets(payload, pktnr))
class BaseMySQLSocketTests(tests.MySQLConnectorTests):

    """Testing mysql.connector.network.BaseMySQLSocket"""

    def setUp(self):
        """Create a BaseMySQLSocket and remember the configured host/port"""
        config = tests.get_mysql_config()
        self._host = config['host']
        self._port = config['port']
        self.cnx = network.BaseMySQLSocket()

    def tearDown(self):
        """Close the connection; errors while closing are ignored"""
        try:
            self.cnx.close_connection()
        except Exception:
            # Best-effort cleanup. Unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt or SystemExit.
            pass

    def _get_socket(self):
        """Return a TCP socket connected to the configured host and port

        The socket is closed again when connecting fails, after which the
        original socket.error is re-raised.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        LOGGER.debug("Get socket for {host}:{port}".format(
            host=self._host, port=self._port))
        try:
            sock.connect((self._host, self._port))
        except socket.error:
            sock.close()
            raise
        return sock

    def test_init(self):
        """MySQLSocket initialization"""
        exp = {
            'sock': None,
            '_connection_timeout': None,
            '_packet_queue': deque(),
            'recvsize': 1024 * 8,
        }

        # Look into __dict__ so only instance attributes are accepted
        for key, value in exp.items():
            self.assertEqual(value, self.cnx.__dict__[key])

    def test_next_packet_number(self):
        """Test packet number property

        Every read of next_packet_number yields the following number and
        stores it in _packet_number; after 255 the number restarts at 0.
        """
        self.assertEqual(0, self.cnx.next_packet_number)
        self.assertEqual(0, self.cnx._packet_number)
        self.assertEqual(1, self.cnx.next_packet_number)
        self.assertEqual(1, self.cnx._packet_number)
        self.cnx._packet_number = 255
        self.assertEqual(0, self.cnx.next_packet_number)

    def test_open_connection(self):
        """Opening a connection is not implemented by the base class"""
        self.assertRaises(NotImplementedError, self.cnx.open_connection)

    def test_get_address(self):
        """Getting the address is not implemented by the base class"""
        self.assertRaises(NotImplementedError, self.cnx.get_address)

    def test_shutdown(self):
        """Shutting down without a socket leaves sock set to None"""
        self.cnx.shutdown()
        self.assertIsNone(self.cnx.sock)

    def test_close_connection(self):
        """Closing without a socket leaves sock set to None"""
        self.cnx.close_connection()
        self.assertIsNone(self.cnx.sock)

    def test_send_plain(self):
        """Send plain data through the socket"""
        # Without a socket, sending has to fail with an OperationalError
        data = b'asddfasdfasdf'
        self.assertRaises(errors.OperationalError, self.cnx.send_plain,
                          data, 0)

        self.cnx.sock = tests.DummySocket()
        data = [
            (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22\x61\x62\x63\x22', 1),
            (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22'
             + (b'\x61' * (constants.MAX_PACKET_LENGTH + 1000)) + b'\x22', 2)]

        self.assertRaises(Exception, self.cnx.send_plain, None, None)

        for value in data:
            # What is sent must equal what _prepare_packets() produces
            exp = network._prepare_packets(*value)
            try:
                self.cnx.send_plain(*value)
            except errors.Error as err:
                self.fail("Failed sending pktnr {}: {}".format(value[1],
                                                                str(err)))
            self.assertEqual(exp, self.cnx.sock._client_sends)
            # Start the next payload with an empty list of sent packets
            self.cnx.sock.reset()

    def test_send_compressed(self):
        """Send compressed data through the socket"""
        # Without a socket, sending has to fail with an OperationalError
        data = b'asddfasdfasdf'
        self.assertRaises(errors.OperationalError, self.cnx.send_compressed,
                          data, 0)

        self.cnx.sock = tests.DummySocket()
        self.assertRaises(Exception, self.cnx.send_compressed, None, None)

        # Small packet
        data = (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22\x61\x62\x63\x22', 1)
        exp = [b'\x11\x00\x00\x02\x00\x00\x00\r\x00\x00\x01\x03SELECT "abc"']
        try:
            self.cnx.send_compressed(*data)
        except errors.Error as err:
            self.fail("Failed sending pktnr {}: {}".format(data[1], err))
        self.assertEqual(exp, self.cnx.sock._client_sends)
        self.cnx.sock.reset()

        # Slightly bigger packet (not getting compressed)
        # NOTE(review): the payload is identical to the small packet above;
        # only the number in the expected outer header differs (3, not 2).
        data = (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22\x61\x62\x63\x22', 1)
        exp = (24, b'\x11\x00\x00\x03\x00\x00\x00\x0d\x00\x00\x01\x03'
                   b'\x53\x45\x4c\x45\x43\x54\x20\x22')
        try:
            self.cnx.send_compressed(*data)
        except errors.Error as err:
            self.fail("Failed sending pktnr {}: {}".format(data[1], str(err)))
        # Compare the total length and the first 20 bytes only
        received = self.cnx.sock._client_sends[0]
        self.assertEqual(exp, (len(received), received[:20]))
        self.cnx.sock.reset()

        # Big packet
        data = (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22'
                + b'\x61' * (constants.MAX_PACKET_LENGTH + 1000) + b'\x22', 2)
        # Each entry: (total length, first 20 bytes) of one sent packet
        exp = [
            (63, b'\x38\x00\x00\x04\x00\x40\x00\x78\x9c\xed\xc1\x31'
                 b'\x0d\x00\x20\x0c\x00\xb0\x04\x8c'),
            (16322, b'\xbb\x3f\x00\x05\xf9\xc3\xff\x78\x9c\xec\xc1\x81'
                    b'\x00\x00\x00\x00\x80\x20\xd6\xfd')]
        try:
            self.cnx.send_compressed(*data)
        except errors.Error as err:
            self.fail("Failed sending pktnr {}: {}".format(data[1], str(err)))
        received = [(len(r), r[:20]) for r in self.cnx.sock._client_sends]
        self.assertEqual(exp, received)
        self.cnx.sock.reset()

    def test_recv_plain(self):
        """Receive data from the socket"""
        self.cnx.sock = tests.DummySocket()

        # The base class raises NotImplementedError for get_address() (see
        # test_get_address), so a stub is installed on the instance.
        def get_address():
            return 'dummy'
        self.cnx.get_address = get_address

        # Receive a packet which is not 4 bytes long
        self.cnx.sock.add_packet(b'\01\01\01')
        self.assertRaises(errors.InterfaceError, self.cnx.recv_plain)

        # Socket fails to receive and produces an error
        self.cnx.sock.raise_socket_error()
        self.assertRaises(errors.OperationalError, self.cnx.recv_plain)

        # Receive packets after a query, SELECT "Ham"
        exp = [
            b'\x01\x00\x00\x01\x01',
            b'\x19\x00\x00\x02\x03\x64\x65\x66\x00\x00\x00\x03\x48\x61\x6d\x00'
            b'\x0c\x21\x00\x09\x00\x00\x00\xfd\x01\x00\x1f\x00\x00',
            b'\x05\x00\x00\x03\xfe\x00\x00\x02\x00',
            b'\x04\x00\x00\x04\x03\x48\x61\x6d',
            b'\x05\x00\x00\x05\xfe\x00\x00\x02\x00',
        ]
        self.cnx.sock.reset()
        self.cnx.sock.add_packets(exp)
        length_exp = len(exp)
        result = []
        packet = self.cnx.recv_plain()
        while packet:
            result.append(packet)
            # Stop once every queued packet was read; no further recv call
            if length_exp == len(result):
                break
            packet = self.cnx.recv_plain()
        self.assertEqual(exp, result)

    def test_recv_compressed(self):
        """Receive compressed data from the socket"""
        self.cnx.sock = tests.DummySocket()

        # The base class raises NotImplementedError for get_address() (see
        # test_get_address), so a stub is installed on the instance.
        def get_address():
            return 'dummy'
        self.cnx.get_address = get_address

        # Receive a packet which is not 7 bytes long
        self.cnx.sock.add_packet(b'\01\01\01\01\01\01')
        self.assertRaises(errors.InterfaceError, self.cnx.recv_compressed)

        # Receive the header of a packet, but nothing more
        self.cnx.sock.add_packet(b'\01\00\00\00\00\00\00')
        self.assertRaises(errors.InterfaceError, self.cnx.recv_compressed)

        # Socket fails to receive and produces an error
        self.cnx.sock.raise_socket_error()
        self.assertRaises(errors.OperationalError, self.cnx.recv_compressed)

    def test_set_connection_timeout(self):
        """Set the connection timeout"""
        exp = 5
        self.cnx.set_connection_timeout(exp)
        self.assertEqual(exp, self.cnx._connection_timeout)
@unittest.skipIf(os.name == 'nt', "Skip UNIX Socket tests on Windows")
class MySQLUnixSocketTests(tests.MySQLConnectorTests):

    """Testing mysql.connector.network.MySQLUnixSocket"""

    def setUp(self):
        """Create a MySQLUnixSocket for the configured Unix socket path"""
        config = tests.get_mysql_config()
        self._unix_socket = config['unix_socket']
        self.cnx = network.MySQLUnixSocket(unix_socket=config['unix_socket'])

    def tearDown(self):
        """Close the connection; errors while closing are ignored"""
        try:
            self.cnx.close_connection()
        except Exception:
            # Best-effort cleanup. Unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt or SystemExit.
            pass

    def test_init(self):
        """MySQLUnixSocket initialization"""
        exp = {
            'unix_socket': self._unix_socket,
        }

        # Look into __dict__ so only instance attributes are accepted
        for key, value in exp.items():
            self.assertEqual(value, self.cnx.__dict__[key])

    def test_get_address(self):
        """Get path to the Unix socket"""
        exp = self._unix_socket
        self.assertEqual(exp, self.cnx.get_address())

    def test_open_connection(self):
        """Open a connection using a Unix socket"""
        # No Windows-specific branch: the class-level skipIf already keeps
        # this test from running when os.name == 'nt'.
        try:
            self.cnx.open_connection()
        except errors.Error as err:
            self.fail(str(err))

    @unittest.skipIf(not tests.SSL_AVAILABLE,
                     "Could not test switch to SSL. Make sure Python supports "
                     "SSL.")
    def test_switch_to_ssl(self):
        """Switch the socket to use SSL"""
        args = {
            'ca': os.path.join(tests.SSL_DIR, 'tests_CA_cert.pem'),
            'cert': os.path.join(tests.SSL_DIR, 'tests_client_cert.pem'),
            'key': os.path.join(tests.SSL_DIR, 'tests_client_key.pem'),
            'cipher': 'AES256-SHA'
        }
        # No socket has been opened yet
        self.assertRaises(errors.InterfaceError,
                          self.cnx.switch_to_ssl, **args)

        # Handshake failure
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(4)
        sock.connect(self._unix_socket)
        # Handing the socket to self.cnx lets tearDown() close it
        self.cnx.sock = sock
        self.assertRaises(errors.InterfaceError,
                          self.cnx.switch_to_ssl, **args)
class MySQLTCPSocketTests(tests.MySQLConnectorTests):

    """Testing mysql.connector.network.MySQLTCPSocket"""

    def setUp(self):
        """Create a MySQLTCPSocket for the configured host and port"""
        config = tests.get_mysql_config()
        self._host = config['host']
        self._port = config['port']
        self.cnx = network.MySQLTCPSocket(host=self._host, port=self._port)

    def tearDown(self):
        """Close the connection; errors while closing are ignored"""
        try:
            self.cnx.close_connection()
        except Exception:
            # Best-effort cleanup. Unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt or SystemExit.
            pass

    def test_init(self):
        """MySQLTCPSocket initialization"""
        exp = {
            'server_host': self._host,
            'server_port': self._port,
        }

        # Look into __dict__ so only instance attributes are accepted
        for key, value in exp.items():
            self.assertEqual(value, self.cnx.__dict__[key])

    def test_get_address(self):
        """Get TCP/IP address"""
        exp = "%s:%s" % (self._host, self._port)
        self.assertEqual(exp, self.cnx.get_address())

    @unittest.skipIf(tests.IPV6_AVAILABLE, "Testing IPv6, not testing IPv4")
    def test_open_connection__ipv4(self):
        """Open a connection using TCP"""
        try:
            self.cnx.open_connection()
        except errors.Error as err:
            self.fail(str(err))

        config = tests.get_mysql_config()
        self._host = config['host']
        self._port = config['port']

        cases = [
            # Address, Expected Family, Should Raise, Force IPv6
            (config['host'], socket.AF_INET, False, False),
        ]

        for case in cases:
            self._test_open_connection(*case)

    @unittest.skipIf(not tests.IPV6_AVAILABLE, "IPv6 testing disabled")
    def test_open_connection__ipv6(self):
        """Open a connection using TCP"""
        config = tests.get_mysql_config()
        self._host = config['host']
        self._port = config['port']

        cases = [
            # Address, Expected Family, Should Raise, Force IPv6
            ('::1', socket.AF_INET6, False, False),
            ('2001::14:06:77', socket.AF_INET6, True, False),
            ('xx:00:xx', socket.AF_INET6, True, False),
        ]

        for case in cases:
            self._test_open_connection(*case)

    def _test_open_connection(self, addr, family, should_raise, force):
        """Open a connection to addr and check the outcome

        addr is connected to on self._port with a timeout of 1. When
        should_raise is true, opening must fail with errors.InterfaceError
        or socket.error; otherwise it must succeed and the socket's _family
        must equal family. force is passed on as force_ipv6.

        The socket is closed on every path, including failing ones.
        """
        sock = None
        try:
            sock = network.MySQLTCPSocket(host=addr,
                                          port=self._port,
                                          force_ipv6=force)
            sock.set_connection_timeout(1)
            sock.open_connection()
        except (errors.InterfaceError, socket.error):
            if not should_raise:
                self.fail('{0} incorrectly raised socket.error'.format(
                    addr))
        else:
            if should_raise:
                self.fail('{0} should have raised socket.error'.format(
                    addr))
            self.assertEqual(family, sock._family,
                             "Family for {0} did not match: expected {1}, "
                             "got {2}".format(addr, family, sock._family))
        finally:
            # Also runs when self.fail() or assertEqual() raised above
            if sock is not None:
                sock.close_connection()

    @unittest.skipIf(not tests.SSL_AVAILABLE,
                     "Could not test switch to SSL. Make sure Python supports "
                     "SSL.")
    def test_switch_to_ssl(self):
        """Switch the socket to use SSL"""
        args = {
            'ca': os.path.join(tests.SSL_DIR, 'tests_CA_cert.pem'),
            'cert': os.path.join(tests.SSL_DIR, 'tests_client_cert.pem'),
            'key': os.path.join(tests.SSL_DIR, 'tests_client_key.pem'),
        }
        # No socket has been opened yet
        self.assertRaises(errors.InterfaceError,
                          self.cnx.switch_to_ssl, **args)

        # Handshake failure
        (family, socktype, proto, _,
         sockaddr) = socket.getaddrinfo(self._host, self._port)[0]
        sock = socket.socket(family, socktype, proto)
        sock.settimeout(4)
        sock.connect(sockaddr)
        # Handing the socket to self.cnx lets tearDown() close it
        self.cnx.sock = sock
        self.assertRaises(errors.InterfaceError,
                          self.cnx.switch_to_ssl, **args)