Removes unnecessary files and sets allowed files

This commit is contained in:
Patrick McDonagh
2016-11-04 17:50:53 -05:00
parent cf98beae33
commit ab005f9cce
102 changed files with 1 additions and 56983 deletions

View File

@@ -8,7 +8,7 @@ from sqlalchemy import and_
import mysql.connector
UPLOAD_FOLDER = '/root/tag-server/flask/app/docs'
ALLOWED_EXTENSIONS = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'])
ALLOWED_EXTENSIONS = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif', 'doc', 'docx', 'xls', 'xlsx', 'zip'])
app = Flask('app', static_url_path='')
app.config.update(

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.0 KiB

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,856 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests
"""
import os
import sys
import re
import socket
import datetime
import inspect
import platform
import unittest
import logging
import shutil
import subprocess
import errno
import traceback
from imp import load_source
from functools import wraps
from pkgutil import walk_packages
LOGGER_NAME = "myconnpy_tests"
LOGGER = logging.getLogger(LOGGER_NAME)
PY2 = sys.version_info[0] == 2
_CACHED_TESTCASES = []
try:
from unittest.util import strclass
except ImportError:
# Python v2
from unittest import _strclass as strclass # pylint: disable=E0611
try:
from unittest.case import SkipTest
except ImportError:
if sys.version_info[0:2] == (3, 1):
from unittest import SkipTest
elif sys.version_info[0:2] == (2, 6):
# Support skipping tests for Python v2.6
from tests.py26 import test_skip, test_skip_if, SkipTest
unittest.skip = test_skip
unittest.skipIf = test_skip_if
else:
LOGGER.error("Could not initialize Python's unittest module")
sys.exit(1)
from lib.cpy_distutils import get_mysql_config_info
SSL_AVAILABLE = True
try:
import ssl
except ImportError:
SSL_AVAILABLE = False
# Note that IPv6 support for Python is checked here, but it can be disabled
# when the bind_address of MySQL was not set to '::1'.
IPV6_AVAILABLE = socket.has_ipv6
OLD_UNITTEST = sys.version_info[0:2] in [(2, 6)]
if os.name == 'nt':
WINDOWS_VERSION = platform.win32_ver()[1]
WINDOWS_VERSION_INFO = [0] * 2
for i, value in enumerate(WINDOWS_VERSION.split('.')[0:2]):
WINDOWS_VERSION_INFO[i] = int(value)
WINDOWS_VERSION_INFO = tuple(WINDOWS_VERSION_INFO)
else:
WINDOWS_VERSION = None
WINDOWS_VERSION_INFO = ()
# Following dictionary holds messages which were added by test cases
# but only logged at the end.
MESSAGES = {
'WARNINGS': [],
'INFO': [],
'SKIPPED': [],
}
OPTIONS_INIT = False
MYSQL_SERVERS_NEEDED = 1
MYSQL_SERVERS = []
MYSQL_VERSION = ()
MYSQL_VERSION_TXT = ''
MYSQL_DUMMY = None
MYSQL_DUMMY_THREAD = None
SSL_DIR = os.path.join('tests', 'data', 'ssl')
SSL_CA = os.path.abspath(os.path.join(SSL_DIR, 'tests_CA_cert.pem'))
SSL_CERT = os.path.abspath(os.path.join(SSL_DIR, 'tests_client_cert.pem'))
SSL_KEY = os.path.abspath(os.path.join(SSL_DIR, 'tests_client_key.pem'))
TEST_BUILD_DIR = None
MYSQL_CAPI = None
DJANGO_VERSION = None
FABRIC_CONFIG = None
__all__ = [
'MySQLConnectorTests',
'get_test_names', 'printmsg',
'LOGGER_NAME',
'DummySocket',
'SSL_DIR',
'get_test_modules',
'MESSAGES',
'setup_logger',
'install_connector',
'TEST_BUILD_DIR',
]
class DummySocket(object):
"""Dummy socket class
This class helps to test socket connection without actually making any
network activity. It is a proxy class using socket.socket.
"""
def __init__(self, *args):
self._socket = socket.socket(*args)
self._server_replies = bytearray(b'')
self._client_sends = []
self._raise_socket_error = 0
def __getattr__(self, attr):
return getattr(self._socket, attr)
def raise_socket_error(self, err=errno.EPERM):
self._raise_socket_error = err
def recv(self, bufsize=4096, flags=0):
if self._raise_socket_error:
raise socket.error(self._raise_socket_error)
res = self._server_replies[0:bufsize]
self._server_replies = self._server_replies[bufsize:]
return res
def recv_into(self, buffer_, nbytes=0, flags=0):
if self._raise_socket_error:
raise socket.error(self._raise_socket_error)
if nbytes == 0:
nbytes = len(buffer_)
try:
buffer_[0:nbytes] = self._server_replies[0:nbytes]
except (IndexError, TypeError) as err:
return 0
except ValueError:
pass
self._server_replies = self._server_replies[nbytes:]
return len(buffer_)
def send(self, string, flags=0):
if self._raise_socket_error:
raise socket.error(self._raise_socket_error)
self._client_sends.append(bytearray(string))
return len(string)
def sendall(self, string, flags=0):
self._client_sends.append(bytearray(string))
return None
def add_packet(self, packet):
self._server_replies += packet
def add_packets(self, packets):
for packet in packets:
self._server_replies += packet
def reset(self):
self._raise_socket_error = 0
self._server_replies = bytearray(b'')
self._client_sends = []
def get_address(self):
return 'dummy'
def get_test_modules():
"""Get list of Python modules containing tests
This function scans the tests/ folder for Python modules which name
start with 'test_'. It will return the dotted name of the module with
submodules together with the first line of the doc string found in
the module.
The result is a sorted list of tuples and each tuple is
(name, module_dotted_path, description)
For example:
('cext_connection', 'tests.cext.cext_connection', 'This module..')
Returns a list of tuples.
"""
global _CACHED_TESTCASES
if _CACHED_TESTCASES:
return _CACHED_TESTCASES
testcases = []
pattern = re.compile('.*test_(.*)')
for finder, name, is_pkg in walk_packages(__path__, prefix=__name__+'.'):
if ('.test_' not in name or
('django' in name and not DJANGO_VERSION) or
('fabric' in name and not FABRIC_CONFIG) or
('cext' in name and not MYSQL_CAPI)):
continue
module_path = os.path.join(finder.path, name.split('.')[-1] + '.py')
dsc = '(description not available)'
try:
mod = load_source(name, module_path)
except IOError as exc:
# Not Python source files
continue
except ImportError as exc:
check_c_extension(exc)
else:
try:
dsc = mod.__doc__.splitlines()[0]
except AttributeError:
# No description available
pass
testcases.append((pattern.match(name).group(1), name, dsc))
testcases.sort(key=lambda x: x[0], reverse=False)
# 'Unimport' modules so they can be correctly imported when tests run
for _, module, _ in testcases:
sys.modules.pop(module, None)
_CACHED_TESTCASES = testcases
return testcases
def get_test_names():
"""Get test names
This functions gets the names of Python modules containing tests. The
name is parsed from files prefixed with 'test_'. For example,
'test_cursor.py' has name 'cursor'.
Returns a list of strings.
"""
pattern = re.compile('.*test_(.*)')
return [mod[0] for mod in get_test_modules()]
def set_nr_mysql_servers(number):
"""Set the number of MySQL servers needed
This functions sets how much MySQL servers are needed for running the
unit tests. The number argument should be a integer between 1 and
16 (16 being the hard limit).
The set_nr_mysql_servers() function is used in test modules, usually at
the very top (after imports).
Raises AttributeError on errors.
"""
global MYSQL_SERVERS_NEEDED # pylint: disable=W0603
if not isinstance(number, int) or (number < 1 or number > 16):
raise AttributeError(
"number of MySQL servers should be a value between 1 and 16")
if number > MYSQL_SERVERS_NEEDED:
MYSQL_SERVERS_NEEDED = number
def fake_hostname():
"""Return a fake hostname
This function returns a string which can be used in the creation of
fake hostname. Note that we do not add a domain name.
Returns a string.
"""
if PY2:
return ''.join(["%02x" % ord(c) for c in os.urandom(4)])
else:
return ''.join(["%02x" % c for c in os.urandom(4)])
def get_mysql_config(name=None, index=None):
"""Get MySQL server configuration for running MySQL server
If no name is given, then we will return the configuration of the
first added.
"""
if not name and not index:
return MYSQL_SERVERS[0].client_config.copy()
if name:
for server in MYSQL_SERVERS:
if server.name == name:
return server.client_config.copy()
elif index:
return MYSQL_SERVERS[index].client_config.copy()
return None
def have_engine(cnx, engine):
"""Check support for given storage engine
This function checks if the MySQL server accessed through cnx has
support for the storage engine.
Returns True or False.
"""
have = False
engine = engine.lower()
cur = cnx.cursor()
# Should use INFORMATION_SCHEMA, but play nice with v4.1
cur.execute("SHOW ENGINES")
rows = cur.fetchall()
for row in rows:
if row[0].lower() == engine:
if row[1].lower() == 'yes':
have = True
break
cur.close()
return have
def cmp_result(result1, result2):
"""Compare results (list of tuples) coming from MySQL
For certain results, like SHOW VARIABLES or SHOW WARNINGS, the
order is unpredictable. To check if what is expected in the
tests, we need to compare each row.
Returns True or False.
"""
try:
if len(result1) != len(result2):
return False
for row in result1:
if row not in result2:
return False
except:
return False
return True
class UTCTimeZone(datetime.tzinfo):
"""UTC"""
def __init__(self):
pass
def utcoffset(self, dt):
return datetime.timedelta(0)
def dst(self, dt):
return datetime.timedelta(0)
def tzname(self, dt):
return 'UTC'
class TestTimeZone(datetime.tzinfo):
"""Test time zone"""
def __init__(self, hours=0):
self._offset = datetime.timedelta(hours=hours)
def utcoffset(self, dt):
return self._offset
def dst(self, dt):
return datetime.timedelta(0)
def tzname(self, dt):
return 'TestZone'
def cnx_config(**extra_config):
def _cnx_config(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not hasattr(self, 'config'):
self.config = get_mysql_config()
if extra_config:
for key, value in extra_config.items():
self.config[key] = value
func(self, *args, **kwargs)
return wrapper
return _cnx_config
def foreach_cnx(*cnx_classes, **extra_config):
def _use_cnx(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not hasattr(self, 'config'):
self.config = get_mysql_config()
if extra_config:
for key, value in extra_config.items():
self.config[key] = value
for cnx_class in cnx_classes or self.all_cnx_classes:
try:
self.cnx = cnx_class(**self.config)
self._testMethodName = "{0} (using {1})".format(
func.__name__, cnx_class.__name__)
except Exception as exc:
if hasattr(self, 'cnx'):
# We will rollback/close later
pass
else:
traceback.print_exc(file=sys.stdout)
raise exc
try:
func(self, *args, **kwargs)
except Exception as exc:
traceback.print_exc(file=sys.stdout)
raise exc
finally:
try:
self.cnx.rollback()
self.cnx.close()
except:
# Might already be closed.
pass
return wrapper
return _use_cnx
class MySQLConnectorTests(unittest.TestCase):
def __init__(self, methodName='runTest'):
from mysql.connector import connection
self.all_cnx_classes = [connection.MySQLConnection]
self.maxDiff = 64
try:
import _mysql_connector
from mysql.connector import connection_cext
except ImportError:
self.have_cext = False
else:
self.have_cext = True
self.all_cnx_classes.append(connection_cext.CMySQLConnection)
super(MySQLConnectorTests, self).__init__(methodName=methodName)
def __str__(self):
classname = strclass(self.__class__)
return "{classname}.{method}".format(
method=self._testMethodName,
classname=re.sub(r"tests\d*.test_", "", classname)
)
def check_attr(self, obj, attrname, default):
cls_name = obj.__class__.__name__
self.assertTrue(
hasattr(obj, attrname),
"{name} object has no '{attr}' attribute".format(
name=cls_name, attr=attrname))
self.assertEqual(
default,
getattr(obj, attrname),
"{name} object's '{attr}' should "
"default to {type_} '{default}'".format(
name=cls_name,
attr=attrname,
type_=type(default).__name__,
default=default))
def check_method(self, obj, method):
cls_name = obj.__class__.__name__
self.assertTrue(
hasattr(obj, method),
"{0} object has no '{1}' method".format(cls_name, method))
self.assertTrue(
inspect.ismethod(getattr(obj, method)),
"{0} object defines {1}, but is not a method".format(
cls_name, method))
def check_args(self, function, supported_arguments):
argspec = inspect.getargspec(function)
function_arguments = dict(zip(argspec[0][1:], argspec[3]))
for argument, default in function_arguments.items():
try:
self.assertEqual(
supported_arguments[argument],
default,
msg="Argument '{0}' has wrong default".format(argument))
except KeyError:
self.fail("Found unsupported or new argument '%s'" % argument)
for argument, default in supported_arguments.items():
if not argument in function_arguments:
self.fail("Supported argument '{0}' fails".format(argument))
if sys.version_info[0:2] >= (3, 4):
def _addSkip(self, result, test_case, reason):
add_skip = getattr(result, 'addSkip', None)
if add_skip:
add_skip(test_case, self._testMethodName + ': ' + reason)
else:
def _addSkip(self, result, reason):
add_skip = getattr(result, 'addSkip', None)
if add_skip:
add_skip(self, self._testMethodName + ': ' + reason)
if sys.version_info[0:2] == (2, 6):
# Backport handy asserts from 2.7
def assertIsInstance(self, obj, cls, msg=None):
if not isinstance(obj, cls):
msg = "{0} is not an instance of {1}".format(
unittest.util.safe_repr(obj), unittest.util.repr(cls))
self.fail(self._formatMessage(msg, msg))
def assertGreater(self, a, b, msg=None):
if not a > b:
msg = "{0} not greater than {1}".format(
unittest.util.safe_repr(a), unittest.util.safe_repr(b))
self.fail(self._formatMessage(msg, msg))
def run(self, result=None):
if sys.version_info[0:2] == (2, 6):
test_method = getattr(self, self._testMethodName)
if (getattr(self.__class__, "__unittest_skip__", False) or
getattr(test_method, "__unittest_skip__", False)):
# We skipped a class
try:
why = (
getattr(self.__class__, '__unittest_skip_why__', '')
or
getattr(test_method, '__unittest_skip_why__', '')
)
self._addSkip(result, why)
finally:
result.stopTest(self)
return
if PY2:
return super(MySQLConnectorTests, self).run(result)
else:
return super().run(result)
def check_namedtuple(self, tocheck, attrs):
for attr in attrs:
try:
getattr(tocheck, attr)
except AttributeError:
self.fail("Attribute '{0}' not part of namedtuple {1}".format(
attr, tocheck))
class TestsCursor(MySQLConnectorTests):
def _test_execute_setup(self, cnx, tbl="myconnpy_cursor", engine="MyISAM"):
self._test_execute_cleanup(cnx, tbl)
stmt_create = (
"CREATE TABLE {table} "
"(col1 INT, col2 VARCHAR(30), PRIMARY KEY (col1))"
"ENGINE={engine}").format(
table=tbl, engine=engine)
try:
cur = cnx.cursor()
cur.execute(stmt_create)
except Exception as err: # pylint: disable=W0703
self.fail("Failed setting up test table; {0}".format(err))
cur.close()
def _test_execute_cleanup(self, cnx, tbl="myconnpy_cursor"):
stmt_drop = "DROP TABLE IF EXISTS {table}".format(table=tbl)
try:
cur = cnx.cursor()
cur.execute(stmt_drop)
except Exception as err: # pylint: disable=W0703
self.fail("Failed cleaning up test table; {0}".format(err))
cur.close()
class CMySQLConnectorTests(MySQLConnectorTests):
def connc_connect_args(self, recache=False):
"""Get connection arguments for the MySQL C API
Get the connection arguments suitable for the MySQL C API
from the Connector/Python arguments. This method sets the member
variable connc_kwargs as well as returning a copy of connc_kwargs.
If recache is True, the information stored in connc_kwargs will
be refreshed.
:return: Dictionary containing connection arguments.
:rtype: dict
"""
self.config = get_mysql_config().copy()
if not self.hasattr('connc_kwargs') or recache is True:
connect_args = [
"host", "user", "password", "database",
"port", "unix_socket", "client_flags"
]
self.connc_kwargs = {}
for key, value in self.config.items():
if key in connect_args:
self.connect_kwargs[key] = value
return self.connc_kwargs.copy()
class CMySQLCursorTests(CMySQLConnectorTests):
_cleanup_tables = []
def setUp(self):
self.config = get_mysql_config()
# Import here allowed
from mysql.connector.connection_cext import CMySQLConnection
self.cnx = CMySQLConnection(**self.config)
def tearDown(self):
self.cleanup_tables(self.cnx)
self.cnx.close()
def setup_table(self, cnx, tbl="myconnpy_cursor", engine="InnoDB"):
self.cleanup_table(cnx, tbl)
stmt_create = (
"CREATE TABLE {table} "
"(col1 INT AUTO_INCREMENT, "
"col2 VARCHAR(30), "
"col3 INT NOT NULL DEFAULT 0, "
"PRIMARY KEY (col1))"
"ENGINE={engine}").format(
table=tbl, engine=engine)
try:
cnx.cmd_query(stmt_create)
except Exception as err: # pylint: disable=W0703
cnx.rollback()
self.fail("Failed setting up test table; {0}".format(err))
else:
cnx.commit()
self._cleanup_tables.append(tbl)
def cleanup_table(self, cnx, tbl="myconnpy_cursor"):
stmt_drop = "DROP TABLE IF EXISTS {table}".format(table=tbl)
# Explicit rollback: uncommited changes could otherwise block
cnx.rollback()
try:
cnx.cmd_query(stmt_drop)
except Exception as err: # pylint: disable=W0703
self.fail("Failed cleaning up test table; {0}".format(err))
if tbl in self._cleanup_tables:
self._cleanup_tables.remove(tbl)
def cleanup_tables(self, cnx):
for tbl in self._cleanup_tables:
self.cleanup_table(cnx, tbl)
def printmsg(msg=None):
if msg is not None:
print(msg)
class SkipTest(Exception):
"""Exception compatible with SkipTest of Python v2.7 and later"""
def _id(obj):
"""Function defined in unittest.case which is needed for decorators"""
return obj
def test_skip(reason):
"""Skip test
This decorator is used by Python v2.6 code to keep compatible with
Python v2.7 (and later) unittest.skip.
"""
def decorator(test):
if not isinstance(test, (type, types.ClassType)):
@wraps(test)
def wrapper(*args, **kwargs):
raise SkipTest(reason)
test = wrapper
test.__unittest_skip__ = True
test.__unittest_skip_why__ = reason
return test
return decorator
def test_skip_if(condition, reason):
"""Skip test if condition is true
This decorator is used by Python v2.6 code to keep compatible with
Python v2.7 (and later) unittest.skipIf.
"""
if condition:
return test_skip(reason)
return _id
def setup_logger(logger, debug=False, logfile=None):
"""Setting up the logger"""
formatter = logging.Formatter(
"%(asctime)s [%(name)s:%(levelname)s] %(message)s")
handler = None
if logfile:
handler = logging.FileHandler(logfile)
else:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
if debug:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.INFO)
LOGGER.handlers = [] # We only need one handler
LOGGER.addHandler(handler)
def install_connector(root_dir, install_dir, connc_location=None):
"""Install Connector/Python in working directory
"""
logfile = 'myconnpy_install.log'
LOGGER.info("Installing Connector/Python in {0}".format(install_dir))
try:
# clean up previous run
if os.path.exists(logfile):
os.unlink(logfile)
shutil.rmtree(install_dir)
except OSError:
pass
cmd = [
sys.executable,
'setup.py',
'clean', '--all', # necessary for removing the build/
]
cmd.extend([
'install',
'--root', install_dir,
'--install-lib', '.'
])
if connc_location:
cmd.extend(['--static', '--with-mysql-capi', connc_location])
prc = subprocess.Popen(cmd, stdin=subprocess.PIPE,
stderr=subprocess.STDOUT, stdout=subprocess.PIPE,
cwd=root_dir)
stdout = prc.communicate()[0]
if prc.returncode is not 0:
with open(logfile, 'wb') as logfp:
logfp.write(stdout)
LOGGER.error("Failed installing Connector/Python, see {log}".format(
log=logfile))
sys.exit(1)
def check_c_extension(exc=None):
"""Check whether we can load the C Extension
This function needs the location of the mysql_config tool to
figure out the location of the MySQL Connector/C libraries. On
Windows it would be the installation location of Connector/C.
:param mysql_config: Location of the mysql_config tool
:param exc: An ImportError exception
"""
if not MYSQL_CAPI:
return
if platform.system() == "Darwin":
libpath_var = 'DYLD_LIBRARY_PATH'
elif platform.system() == "Windows":
libpath_var = 'PATH'
else:
libpath_var = 'LD_LIBRARY_PATH'
if not os.path.exists(MYSQL_CAPI):
LOGGER.error("MySQL Connector/C not available using '%s'", MYSQL_CAPI)
if os.name == 'posix':
if os.path.isdir(MYSQL_CAPI):
mysql_config = os.path.join(MYSQL_CAPI, 'bin', 'mysql_config')
else:
mysql_config = MYSQL_CAPI
lib_dir = get_mysql_config_info(mysql_config)['lib_dir']
elif os.path.isdir(MYSQL_CAPI):
lib_dir = os.path.join(MYSQL_CAPI, 'lib')
else:
LOGGER.error("C Extension not supported on %s", os.name)
sys.exit(1)
error_msg = ''
if not exc:
try:
import _mysql_connector
except ImportError as exc:
error_msg = str(exc).strip()
else:
assert(isinstance(exc, ImportError))
error_msg = str(exc).strip()
if not error_msg:
# Nothing to do
return
match = re.match('.*Library not loaded:\s(.+)\n.*', error_msg)
if match:
lib_name = match.group(1)
LOGGER.error(
"MySQL Client library not loaded. Make sure the shared library "
"'%s' can be loaded by Python. Tip: Add folder '%s' to "
"environment variable '%s'.",
lib_name, lib_dir, libpath_var)
sys.exit(1)
else:
LOGGER.error("C Extension not available: %s", error_msg)
sys.exit(1)

View File

@@ -1,770 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Testing the C Extension MySQL C API
"""
import logging
import os
import re
import unittest
import tests
from mysql.connector.constants import ServerFlag, ClientFlag
try:
from _mysql_connector import MySQL, MySQLError, MySQLInterfaceError
except ImportError:
CEXT_MYSQL_AVAILABLE = False
else:
CEXT_MYSQL_AVAILABLE = True
LOGGER = logging.getLogger(tests.LOGGER_NAME)
def get_variables(cnx, pattern=None, variables=None, global_vars=False):
"""Get session or global system variables
We use the MySQL connection cnx to query the INFORMATION_SCHEMA
table SESSION_VARIABLES or, when global_vars is True, the table
GLOBAL_VARIABLES.
:param cnx: Valid MySQL connection
:param pattern: Pattern to use (used for LIKE)
:param variables: Variables to query for
:return: Dictionary containing variables with values or empty dict
:rtype : dict
"""
format_vars = {
'where_clause': '',
'where': '',
}
ver = cnx.get_server_version()
if ver >= (5, 7, 6):
table_global_vars = 'global_variables'
table_session_vars = 'session_variables'
format_vars['schema'] = 'performance_schema'
else:
table_global_vars = 'GLOBAL_VARIABLES'
table_session_vars = 'SESSION_VARIABLES'
format_vars['schema'] = 'INFORMATION_SCHEMA'
if global_vars is True:
format_vars['table'] = table_global_vars
else:
format_vars['table'] = table_session_vars
query = "SELECT * FROM {schema}.{table} {where} {where_clause}"
where = []
if pattern:
where.append('VARIABLE_NAME LIKE "{0}"'.format(pattern))
if variables:
where.append('VARIABLE_NAME IN ({0})'.format(
','.join([ "'{0}'".format(name) for name in variables ])
))
if where:
format_vars['where'] = 'WHERE'
format_vars['where_clause'] = ' OR '.join(where)
cnx.query(query.format(**format_vars))
result = {}
row = cnx.fetch_row()
while row:
result[row[0].lower()] = row[1]
row = cnx.fetch_row()
cnx.free_result()
return result
def fetch_rows(cnx, query=None):
"""Execute query and fetch first result set
This function will use connection cnx and execute the query. All
rows are then returned as a list of tuples.
:param cnx: Valid MySQL connection
:param query: SQL statement to execute
:return: List of tuples
:rtype: list
"""
rows = []
if query:
cnx.query(query)
if cnx.have_result_set:
row = cnx.fetch_row()
while row:
rows.append(row)
row = cnx.fetch_row()
if cnx.next_result():
raise Exception("fetch_rows does not work with multi results")
cnx.free_result()
return rows
@unittest.skipIf(CEXT_MYSQL_AVAILABLE == False, "C Extension not available")
class CExtMySQLTests(tests.MySQLConnectorTests):
"""Test the MySQL class in the C Extension"""
def setUp(self):
self.config = tests.get_mysql_config()
connect_args = [
"host", "user", "password", "database",
"port", "unix_socket", "client_flags"
]
self.connect_kwargs = {}
for key, value in self.config.items():
if key in connect_args:
self.connect_kwargs[key] = value
if 'client_flags' not in self.connect_kwargs:
self.connect_kwargs['client_flags'] = ClientFlag.get_default()
def test___init__(self):
cmy = MySQL()
self.assertEqual(False, cmy.buffered())
self.assertEqual(False, cmy.raw())
cmy = MySQL(buffered=True, raw=True)
self.assertEqual(True, cmy.buffered())
self.assertEqual(True, cmy.raw())
exp = 'gbk'
cmy = MySQL(charset_name=exp)
cmy.connect(**self.connect_kwargs)
self.assertEqual(exp, cmy.character_set_name())
def test_buffered(self):
cmy = MySQL()
self.assertEqual(False, cmy.buffered())
cmy.buffered(True)
self.assertEqual(True, cmy.buffered())
cmy.buffered(False)
self.assertEqual(False, cmy.buffered())
self.assertRaises(TypeError, cmy.buffered, 'a')
def test_raw(self):
cmy = MySQL()
self.assertEqual(False, cmy.raw())
cmy.raw(True)
self.assertEqual(True, cmy.raw())
cmy.raw(False)
self.assertEqual(False, cmy.raw())
self.assertRaises(TypeError, cmy.raw, 'a')
def test_connected(self):
config = self.connect_kwargs.copy()
cmy = MySQL()
self.assertFalse(cmy.connected())
cmy.connect(**config)
self.assertTrue(cmy.connected())
cmy.close()
self.assertFalse(cmy.connected())
def test_connect(self):
config = self.connect_kwargs.copy()
cmy = MySQL()
self.assertFalse(cmy.ping())
# Using Unix Socket
cmy.connect(**config)
self.assertTrue(cmy.ping())
# Using TCP
config['unix_socket'] = None
cmy.connect(**config)
self.assertTrue(cmy.ping())
self.assertEqual(None, cmy.close())
self.assertFalse(cmy.ping())
self.assertEqual(None, cmy.close())
def test_close(self):
"""
MySQL_close() is being tested in test_connected
Unless something needs to be tested additionally, leave this
test case as placeholder.
"""
pass
def test_ping(self):
"""
MySQL_ping() is being tested in test_connected
Unless something needs to be tested additionally, leave this
test case as placeholder.
"""
pass
def test_escape_string(self):
cases = [
('new\nline', b'new\\nline'),
('carriage\rreturn', b'carriage\\rreturn'),
('control\x1aZ', b'control\\ZZ'),
("single'quote", b"single\\'quote"),
('double"quote', b'double\\"quote'),
('back\slash', b'back\\\\slash'),
('nul\0char', b'nul\\0char'),
(u"Kangxi⽃\0", b'Kangxi\xe2\xbd\x83\\0\xe2\xbd\x87'),
(b'bytes\0ob\'j\n"ct\x1a', b'bytes\\0ob\\\'j\\n\\"ct\\Z'),
]
cmy = MySQL()
cmy.connect(**self.connect_kwargs)
unicode_string = u"Kangxi⽃\0"
self.assertRaises(UnicodeEncodeError, cmy.escape_string, unicode_string)
cmy.set_character_set("UTF8")
for value, exp in cases:
self.assertEqual(exp, cmy.escape_string(value))
self.assertRaises(TypeError, cmy.escape_string, 1234);
def test_get_character_set_info(self):
cmy = MySQL()
self.assertRaises(MySQLInterfaceError, cmy.get_character_set_info)
cmy.connect(**self.connect_kwargs)
# We go by the default of MySQL, which is latin1/swedish_ci
exp = {'comment': '', 'name': 'latin1_swedish_ci',
'csname': 'latin1', 'mbmaxlen': 1, 'number': 8, 'mbminlen': 1}
result = cmy.get_character_set_info()
# make 'comment' deterministic
result['comment'] = ''
self.assertEqual(exp, result)
cmy.query("SET NAMES utf8")
cmy.set_character_set('utf8')
exp = {'comment': '', 'name': 'utf8_general_ci',
'csname': 'utf8', 'mbmaxlen': 3, 'number': 33, 'mbminlen': 1}
result = cmy.get_character_set_info()
# make 'comment' deterministic
result['comment'] = ''
self.assertEqual(exp, result)
def test_get_proto_info(self):
cmy = MySQL()
self.assertRaises(MySQLInterfaceError, cmy.get_proto_info)
cmy.connect(**self.connect_kwargs)
self.assertEqual(10, cmy.get_proto_info())
def test_get_server_info(self):
cmy = MySQL()
self.assertRaises(MySQLInterfaceError, cmy.get_server_info)
cmy.connect(**self.connect_kwargs)
version = cmy.get_server_version()
info = cmy.get_server_info()
self.assertIsInstance(info, str)
self.assertTrue(info.startswith('.'.join([str(v) for v in version])))
def test_get_server_version(self):
cmy = MySQL()
self.assertRaises(MySQLInterfaceError, cmy.get_server_version)
cmy.connect(**self.connect_kwargs)
version = cmy.get_server_version()
self.assertIsInstance(version, tuple)
self.assertEqual(3, len(version))
self.assertTrue(all([isinstance(v, int) and v > 0 for v in version]))
self.assertTrue(3 < version[0] < 7)
self.assertTrue(0 < version[1] < 20)
self.assertTrue(0 < version[2] < 99)
def test_thread_id(self):
cmy = MySQL()
self.assertRaises(MySQLInterfaceError, cmy.thread_id)
cmy.connect(**self.connect_kwargs)
if tests.PY2:
self.assertIsInstance(cmy.thread_id(), long)
else:
self.assertIsInstance(cmy.thread_id(), int)
self.assertGreater(cmy.thread_id(), 0)
thread_id = cmy.thread_id()
cmy.close()
self.assertRaises(MySQLError, cmy.thread_id)
def test_select_db(self):
cmy = MySQL(buffered=True)
cmy.connect(**self.connect_kwargs)
cmy.select_db('mysql')
cmy.query("SELECT DATABASE()")
self.assertEqual(b'mysql', cmy.fetch_row()[0])
cmy.free_result()
cmy.select_db('myconnpy')
cmy.query("SELECT DATABASE()")
self.assertEqual(b'myconnpy', cmy.fetch_row()[0])
cmy.free_result()
def test_affected_rows(self):
cmy = MySQL(buffered=True)
cmy.connect(**self.connect_kwargs)
table = "affected_rows"
cmy.select_db('myconnpy')
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
cmy.query("CREATE TABLE {0} (c1 INT, c2 INT)".format(table))
cmy.query("INSERT INTO {0} (c1, c2) VALUES "
"(1, 10), (2, 20), (3, 30)".format(table))
self.assertEqual(3, cmy.affected_rows())
cmy.query("UPDATE {0} SET c2 = c2 + 1 WHERE c1 < 3".format(table))
self.assertEqual(2, cmy.affected_rows())
cmy.query("DELETE FROM {0} WHERE c1 IN (1, 2, 3)".format(table))
self.assertEqual(3, cmy.affected_rows())
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
def test_field_count(self):
cmy = MySQL(buffered=True)
cmy.connect(**self.connect_kwargs)
table = "field_count"
cmy.select_db('myconnpy')
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
cmy.query("CREATE TABLE {0} (c1 INT, c2 INT, c3 INT)".format(table))
cmy.query("SELECT * FROM {0}".format(table))
self.assertEqual(3, cmy.field_count())
cmy.free_result()
cmy.query("INSERT INTO {0} (c1, c2, c3) VALUES "
"(1, 10, 100)".format(table))
cmy.commit()
cmy.query("SELECT * FROM {0}".format(table))
self.assertEqual(3, cmy.field_count())
cmy.free_result()
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
def test_autocommit(self):
cmy1 = MySQL(buffered=True)
cmy1.connect(**self.connect_kwargs)
cmy2 = MySQL(buffered=True)
cmy2.connect(**self.connect_kwargs)
self.assertRaises(ValueError, cmy1.autocommit, 'ham')
self.assertRaises(ValueError, cmy1.autocommit, 1)
self.assertRaises(ValueError, cmy1.autocommit, None)
table = "autocommit_test"
# For the test we start off by making sure the autocommit is off
# for both sessions
cmy1.query("SELECT @@global.autocommit")
if cmy1.fetch_row()[0] != 1:
cmy1.query("SET @@session.autocommit = 0")
cmy2.query("SET @@session.autocommit = 0")
cmy1.query("DROP TABLE IF EXISTS {0}".format(table))
cmy1.query("CREATE TABLE {0} (c1 INT)".format(table))
# Turn AUTOCOMMIT on
cmy1.autocommit(True)
cmy1.query("INSERT INTO {0} (c1) VALUES "
"(1), (2), (3)".format(table))
cmy2.query("SELECT * FROM {0}".format(table))
self.assertEqual(3, cmy2.num_rows())
rows = fetch_rows(cmy2)
# Turn AUTOCOMMIT off
cmy1.autocommit(False)
cmy1.query("INSERT INTO {0} (c1) VALUES "
"(4), (5), (6)".format(table))
cmy2.query("SELECT * FROM {0} WHERE c1 > 3".format(table))
self.assertEqual([], fetch_rows(cmy2))
cmy1.commit()
cmy2.query("SELECT * FROM {0} WHERE c1 > 3".format(table))
self.assertEqual([(4,), (5,), (6,)], fetch_rows(cmy2))
cmy1.close()
cmy2.close()
def test_commit(self):
cmy1 = MySQL(buffered=True)
cmy1.connect(**self.connect_kwargs)
cmy2 = MySQL(buffered=True)
cmy2.connect(**self.connect_kwargs)
table = "commit_test"
cmy1.query("DROP TABLE IF EXISTS {0}".format(table))
cmy1.query("CREATE TABLE {0} (c1 INT)".format(table))
cmy1.query("START TRANSACTION")
cmy1.query("INSERT INTO {0} (c1) VALUES "
"(1), (2), (3)".format(table))
cmy2.query("SELECT * FROM {0}".format(table))
self.assertEqual([], fetch_rows(cmy2))
cmy1.commit()
cmy2.query("SELECT * FROM {0}".format(table))
self.assertEqual([(1,), (2,), (3,)], fetch_rows(cmy2))
cmy1.query("DROP TABLE IF EXISTS {0}".format(table))
def test_change_user(self):
connect_kwargs = self.connect_kwargs.copy()
connect_kwargs['unix_socket'] = None
cmy1 = MySQL(buffered=True)
cmy1.connect(**connect_kwargs)
cmy2 = MySQL(buffered=True)
cmy2.connect(**connect_kwargs)
new_user = {
'user': 'cextuser',
'host': self.config['host'],
'database': self.connect_kwargs['database'],
'password': 'connc',
}
try:
cmy1.query("DROP USER '{user}'@'{host}'".format(**new_user))
except MySQLInterfaceError:
# Probably not created
pass
stmt = ("CREATE USER '{user}'@'{host}' IDENTIFIED WITH "
"mysql_native_password").format(**new_user)
cmy1.query(stmt)
cmy1.query("SET old_passwords = 0")
res = cmy1.query("SET PASSWORD FOR '{user}'@'{host}' = "
"PASSWORD('{password}')".format(**new_user))
cmy1.query("GRANT ALL ON {database}.* "
"TO '{user}'@'{host}'".format(**new_user))
cmy2.query("SHOW GRANTS FOR {user}@{host}".format(**new_user))
cmy2.query("SELECT USER()")
orig_user = cmy2.fetch_row()[0]
cmy2.free_result()
cmy2.change_user(user=new_user['user'], password=new_user['password'],
database=new_user['database'])
cmy2.query("SELECT USER()")
current_user = cmy2.fetch_row()[0]
self.assertNotEqual(orig_user, current_user)
self.assertTrue(
u"{user}@".format(**new_user) in current_user.decode('utf8'))
cmy2.free_result()
def test_character_set_name(self):
cmy1 = MySQL(buffered=True)
self.assertRaises(MySQLInterfaceError, cmy1.character_set_name)
cmy1.connect(**self.connect_kwargs)
self.assertEqual('latin1', cmy1.character_set_name())
def test_set_character_set(self):
cmy1 = MySQL(buffered=True)
self.assertRaises(MySQLInterfaceError, cmy1.set_character_set, 'latin2')
cmy1.connect(**self.connect_kwargs)
orig = cmy1.character_set_name()
cmy1.set_character_set('utf8')
charset = cmy1.character_set_name()
self.assertNotEqual(orig, charset)
self.assertEqual('utf8', charset)
self.assertRaises(MySQLInterfaceError,
cmy1.set_character_set, 'ham_spam')
variables = ('character_set_connection',)
exp = {b'character_set_connection': b'utf8',}
self.assertEqual(exp, get_variables(cmy1, variables=variables))
exp = {b'character_set_connection': b'big5',}
cmy1.set_character_set('big5')
self.assertEqual(exp, get_variables(cmy1, variables=variables))
def test_get_ssl_cipher(self):
cmy1 = MySQL(buffered=True)
self.assertRaises(MySQLInterfaceError, cmy1.get_ssl_cipher)
cmy1.connect(**self.connect_kwargs)
self.assertEqual(None, cmy1.get_ssl_cipher())
def test_hex_string(self):
config = self.connect_kwargs.copy()
cmy = MySQL(buffered=True)
table = "hex_string"
cases = {
'utf8': [
(u'ham', b"X'68616D'"),
],
'big5': [
(u'\u5C62', b"X'B9F0'")
],
'sjis': [
(u'\u005c', b"X'5C'"),
],
'gbk': [
(u'赵孟頫', b"X'D5D4C3CFEE5C'"),
(u'\\\\', b"X'D5D45CC3CF5CEE5C5C'"),
(u'', b"X'DF64'")
],
'ascii': [
('\x5c\x00\x5c', b"X'5C005C'"),
],
}
cmy.connect(**config)
def create_table(charset):
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
cmy.query("CREATE TABLE {0} (id INT, "
"c1 VARCHAR(400)) CHARACTER SET {1}".format(
table, charset))
insert = "INSERT INTO {0} (id, c1) VALUES ({{id}}, {{hex}})".format(
table)
select = "SELECT c1 FROM {0} WHERE id = {{id}}".format(table)
for encoding, data in cases.items():
create_table(encoding)
for i, info in enumerate(data):
case, exp = info
cmy.set_character_set(encoding)
hexed = cmy.hex_string(case.encode(encoding))
self.assertEqual(exp, hexed)
cmy.query(insert.format(id=i, hex=hexed.decode()))
cmy.query(select.format(id=i))
try:
fetched = fetch_rows(cmy)[0][0]
except UnicodeEncodeError:
self.fail("Could not encode {0}".format(encoding))
self.assertEqual(case, fetched.decode(encoding),
"Failed with case {0}/{1}".format(i, encoding))
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
def test_insert_id(self):
cmy = MySQL(buffered=True)
cmy.connect(**self.connect_kwargs)
table = "insert_id_test"
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
cmy.query("CREATE TABLE {0} (id INT AUTO_INCREMENT KEY)".format(table))
self.assertEqual(0, cmy.insert_id())
cmy.query("INSERT INTO {0} VALUES ()".format(table))
self.assertEqual(1, cmy.insert_id())
# Multiple-row
cmy.query("INSERT INTO {0} VALUES (), ()".format(table))
self.assertEqual(2, cmy.insert_id())
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
def test_warning_count(self):
cmy = MySQL()
cmy.connect(**self.connect_kwargs)
cmy.query("SELECT 'a' + 'b'", buffered=False)
fetch_rows(cmy)
self.assertEqual(2, cmy.warning_count())
cmy.query("SELECT 1 + 1", buffered=True)
self.assertEqual(0, cmy.warning_count())
fetch_rows(cmy)
def test_get_client_info(self):
cmy = MySQL(buffered=True)
match = re.match(r"(\d+\.\d+.\d+)(.*)", cmy.get_client_info())
self.assertNotEqual(None, match)
def test_get_client_version(self):
cmy = MySQL(buffered=True)
version = cmy.get_client_version()
self.assertTrue(isinstance(version, tuple))
self.assertTrue(all([ isinstance(v, int) for v in version]))
def test_get_host_info(self):
config = self.connect_kwargs.copy()
cmy = MySQL(buffered=True)
self.assertRaises(MySQLInterfaceError, cmy.get_host_info)
cmy.connect(**config)
if os.name == 'posix':
# On POSIX systems we would be connected by UNIX socket
self.assertTrue('via UNIX socket' in cmy.get_host_info())
# Connect using TCP/IP
config['unix_socket'] = None
cmy.connect(**config)
self.assertTrue('via TCP/IP' in cmy.get_host_info())
def test_query(self):
config = self.connect_kwargs.copy()
cmy = MySQL(buffered=True)
self.assertRaises(MySQLInterfaceError, cmy.query)
cmy.connect(**config)
self.assertRaises(MySQLInterfaceError, cmy.query, "SELECT spam")
self.assertTrue(cmy.query("SET @ham = 4"))
self.assertEqual(None, cmy.num_fields())
self.assertEqual(0, cmy.field_count())
self.assertTrue(cmy.query("SELECT @ham"))
self.assertEqual(4, cmy.fetch_row()[0])
self.assertEqual(None, cmy.fetch_row())
cmy.free_result()
self.assertTrue(cmy.query("SELECT 'ham', 'spam', 5", raw=True))
row = cmy.fetch_row()
self.assertTrue(isinstance(row[0], bytearray))
self.assertEqual(bytearray(b'spam'), row[1])
self.assertEqual(None, cmy.fetch_row())
cmy.free_result()
def test_st_server_status(self):
config = self.connect_kwargs.copy()
cmy = MySQL(buffered=True)
self.assertEqual(0, cmy.st_server_status())
cmy.connect(**config)
self.assertTrue(
cmy.st_server_status() & ServerFlag.STATUS_AUTOCOMMIT)
cmy.autocommit(False)
self.assertFalse(
cmy.st_server_status() & ServerFlag.STATUS_AUTOCOMMIT)
cmy.query("START TRANSACTION")
self.assertTrue(
cmy.st_server_status() & ServerFlag.STATUS_IN_TRANS)
cmy.query("ROLLBACK")
self.assertFalse(
cmy.st_server_status() & ServerFlag.STATUS_IN_TRANS)
def test_rollback(self):
cmy1 = MySQL(buffered=True)
cmy1.connect(**self.connect_kwargs)
cmy2 = MySQL(buffered=True)
cmy2.connect(**self.connect_kwargs)
table = "commit_test"
cmy1.query("DROP TABLE IF EXISTS {0}".format(table))
cmy1.query("CREATE TABLE {0} (c1 INT)".format(table))
cmy1.query("START TRANSACTION")
cmy1.query("INSERT INTO {0} (c1) VALUES "
"(1), (2), (3)".format(table))
cmy1.commit()
cmy2.query("SELECT * FROM {0}".format(table))
self.assertEqual([(1,), (2,), (3,)], fetch_rows(cmy2))
cmy1.query("START TRANSACTION")
cmy1.query("INSERT INTO {0} (c1) VALUES "
"(4), (5), (6)".format(table))
cmy1.rollback()
cmy2.query("SELECT * FROM {0}".format(table))
self.assertEqual(3, cmy2.num_rows())
cmy1.query("DROP TABLE IF EXISTS {0}".format(table))
def test_next_result(self):
cmy = MySQL()
cmy.connect(**self.connect_kwargs)
table = "next_result_test"
cmy.query("DROP TABLE IF EXISTS {0}".format(table))
cmy.query("CREATE TABLE {0} (c1 INT AUTO_INCREMENT KEY)".format(table))
var_names = ('"HAVE_CRYPT"', '"CHARACTER_SET_CONNECTION"')
queries = (
"SELECT 'HAM'",
"INSERT INTO {0} () VALUES ()".format(table),
"SELECT 'SPAM'",
)
exp = [
[(b'HAM',)],
{'insert_id': 1, 'affected': 1},
[(b'SPAM',)]
]
result = []
have_more = cmy.query(';'.join(queries))
self.assertTrue(have_more)
while have_more:
if cmy.have_result_set:
rows = []
row = cmy.fetch_row()
while row:
rows.append(row)
row = cmy.fetch_row()
result.append(rows)
else:
result.append({
"affected": cmy.affected_rows(),
"insert_id": cmy.insert_id()
})
have_more = cmy.next_result()
self.assertEqual(exp, result)

View File

@@ -1,123 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Testing connection.CMySQLConnection class using the C Extension
"""
import tests
from mysql.connector import errors
from mysql.connector.constants import ClientFlag, flag_is_set
from mysql.connector.connection import MySQLConnection
from mysql.connector.connection_cext import CMySQLConnection
class CMySQLConnectionTests(tests.MySQLConnectorTests):
def setUp(self):
config = tests.get_mysql_config()
self.cnx = CMySQLConnection(**config)
self.pcnx = MySQLConnection(**config)
def test__info_query(self):
query = "SELECT 1, 'a', 2, 'b'"
exp = (1, 'a', 2, 'b')
self.assertEqual(exp, self.cnx.info_query(query))
self.assertRaises(errors.InterfaceError, self.cnx.info_query,
"SHOW VARIABLES LIKE '%char%'")
def test_client_flags(self):
defaults = ClientFlag.default
set_flags = self.cnx._cmysql.st_client_flag()
for flag in defaults:
self.assertTrue(flag_is_set(flag, set_flags))
def test_get_rows(self):
self.assertRaises(errors.InternalError, self.cnx.get_rows)
query = "SHOW STATUS LIKE 'Aborted_c%'"
self.cnx.cmd_query(query)
self.assertRaises(AttributeError, self.cnx.get_rows, 0)
self.assertRaises(AttributeError, self.cnx.get_rows, -10)
self.assertEqual(2, len(self.cnx.get_rows()))
self.cnx.free_result()
self.cnx.cmd_query(query)
self.assertEqual(1, len(self.cnx.get_rows(count=1)))
self.assertEqual(1, len(self.cnx.get_rows(count=1)))
self.assertEqual([], self.cnx.get_rows(count=1))
self.cnx.free_result()
def test_cmd_init_db(self):
query = "SELECT DATABASE()"
self.cnx.cmd_init_db('mysql')
self.assertEqual('mysql', self.cnx.info_query(query)[0])
self.cnx.cmd_init_db('myconnpy')
self.assertEqual('myconnpy', self.cnx.info_query(query)[0])
def test_cmd_query(self):
query = "SHOW STATUS LIKE 'Aborted_c%'"
info = self.cnx.cmd_query(query)
exp = {
'eof': {'status_flag': 32, 'warning_count': 0},
'columns': [
['Variable_name', 253, None, None, None, None, 0, 1],
('Value', 253, None, None, None, None, 1, 0)
]
}
if tests.MYSQL_VERSION >= (5, 7, 10):
exp['columns'][0][7] = 4097
exp['eof']['status_flag'] = 16385
exp['columns'][0] = tuple(exp['columns'][0])
self.assertEqual(exp, info)
rows = self.cnx.get_rows()
vars = [ row[0] for row in rows ]
self.assertEqual(2, len(rows))
vars.sort()
exp = ['Aborted_clients', 'Aborted_connects']
self.assertEqual(exp, vars)
exp = ['Value', 'Variable_name']
fields = [fld[0] for fld in info['columns']]
fields.sort()
self.assertEqual(exp, fields)
self.cnx.free_result()
info = self.cnx.cmd_query("SET @a = 1")
exp = {
'warning_count': 0, 'insert_id': 0, 'affected_rows': 0,
'server_status': 0, 'field_count': 0
}
self.assertEqual(exp, info)

View File

@@ -1,573 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Testing the C Extension cursors
"""
import logging
import unittest
from mysql.connector import errors, errorcode
import tests
try:
from _mysql_connector import (
MySQL, MySQLError, MySQLInterfaceError,
)
except ImportError:
HAVE_CMYSQL = False
else:
HAVE_CMYSQL = True
from mysql.connector.connection_cext import CMySQLConnection
from mysql.connector.cursor_cext import (
CMySQLCursor, CMySQLCursorBuffered, CMySQLCursorRaw
)
LOGGER = logging.getLogger(tests.LOGGER_NAME)
@unittest.skipIf(HAVE_CMYSQL == False, "C Extension not available")
class CExtMySQLCursorTests(tests.CMySQLCursorTests):
def _get_cursor(self, cnx=None):
if not cnx:
cnx = CMySQLConnection(**self.config)
return CMySQLCursor(connection=cnx)
def test___init__(self):
self.assertRaises(errors.InterfaceError, CMySQLCursor, connection='ham')
cur = self._get_cursor(self.cnx)
self.assertTrue(hex(id(self.cnx)).upper()[2:]
in repr(cur._cnx).upper())
def test_lastrowid(self):
cur = self._get_cursor(self.cnx)
tbl = 'test_lastrowid'
self.setup_table(self.cnx, tbl)
cur.execute("INSERT INTO {0} (col1) VALUES (1)".format(tbl))
self.assertEqual(1, cur.lastrowid)
cur.execute("INSERT INTO {0} () VALUES ()".format(tbl))
self.assertEqual(2, cur.lastrowid)
cur.execute("INSERT INTO {0} () VALUES (),()".format(tbl))
self.assertEqual(3, cur.lastrowid)
cur.execute("INSERT INTO {0} () VALUES ()".format(tbl))
self.assertEqual(5, cur.lastrowid)
def test__fetch_warnings(self):
self.cnx.get_warnings = True
cur = self._get_cursor(self.cnx)
cur._cnx = None
self.assertRaises(errors.InterfaceError, cur._fetch_warnings)
cur = self._get_cursor(self.cnx)
cur.execute("SELECT 'a' + 'b'")
cur.fetchall()
exp = [
('Warning', 1292, "Truncated incorrect DOUBLE value: 'a'"),
('Warning', 1292, "Truncated incorrect DOUBLE value: 'b'")
]
res = cur._fetch_warnings()
self.assertTrue(tests.cmp_result(exp, res))
self.assertEqual(len(exp), cur._warning_count)
def test_execute(self):
self.cnx.get_warnings = True
cur = self._get_cursor(self.cnx)
self.assertEqual(None, cur.execute(None))
self.assertRaises(errors.ProgrammingError, cur.execute,
'SELECT %s,%s,%s', ('foo', 'bar',))
cur.execute("SELECT 'a' + 'b'")
cur.fetchall()
exp = [
('Warning', 1292, "Truncated incorrect DOUBLE value: 'a'"),
('Warning', 1292, "Truncated incorrect DOUBLE value: 'b'")
]
self.assertTrue(tests.cmp_result(exp, cur._warnings))
self.cnx.get_warnings = False
cur.execute("SELECT BINARY 'ham'")
exp = [(b'ham',)]
self.assertEqual(exp, cur.fetchall())
cur.close()
tbl = 'myconnpy_cursor'
self.setup_table(self.cnx, tbl)
cur = self._get_cursor(self.cnx)
stmt_insert = "INSERT INTO {0} (col1,col2) VALUES (%s,%s)".format(tbl)
res = cur.execute(stmt_insert, (1, 100))
self.assertEqual(None, res, "Return value of execute() is wrong.")
stmt_select = "SELECT col1,col2 FROM {0} ORDER BY col1".format(tbl)
cur.execute(stmt_select)
self.assertEqual([(1, '100')],
cur.fetchall(), "Insert test failed")
data = {'id': 2}
stmt = "SELECT col1,col2 FROM {0} WHERE col1 <= %(id)s".format(tbl)
cur.execute(stmt, data)
self.assertEqual([(1, '100')], cur.fetchall())
cur.close()
def test_executemany__errors(self):
self.cnx.get_warnings = True
cur = self._get_cursor(self.cnx)
self.assertEqual(None, cur.executemany(None, []))
cur = self._get_cursor(self.cnx)
self.assertRaises(errors.ProgrammingError, cur.executemany,
'programming error with string', 'foo')
self.assertRaises(errors.ProgrammingError, cur.executemany,
'programming error with 1 element list', ['foo'])
self.assertEqual(None, cur.executemany('empty params', []))
self.assertEqual(None, cur.executemany('params is None', None))
self.assertRaises(errors.ProgrammingError, cur.executemany,
'foo', ['foo'])
self.assertRaises(errors.ProgrammingError, cur.executemany,
'SELECT %s', [('foo',), 'foo'])
self.assertRaises(errors.ProgrammingError,
cur.executemany,
"INSERT INTO t1 1 %s", [(1,), (2,)])
cur.executemany("SELECT SHA1(%s)", [('foo',), ('bar',)])
self.assertEqual(None, cur.fetchone())
def test_executemany(self):
tbl = 'myconnpy_cursor'
self.setup_table(self.cnx, tbl)
stmt_insert = "INSERT INTO {0} (col1,col2) VALUES (%s,%s)".format(tbl)
stmt_select = "SELECT col1,col2 FROM {0} ORDER BY col1".format(tbl)
cur = self._get_cursor(self.cnx)
res = cur.executemany(stmt_insert, [(1, 100), (2, 200), (3, 300)])
self.assertEqual(3, cur.rowcount)
res = cur.executemany("SELECT %s", [('f',), ('o',), ('o',)])
self.assertEqual(3, cur.rowcount)
data = [{'id': 2}, {'id': 3}]
stmt = "SELECT * FROM {0} WHERE col1 <= %(id)s".format(tbl)
cur.executemany(stmt, data)
self.assertEqual(5, cur.rowcount)
cur.execute(stmt_select)
self.assertEqual([(1, '100'), (2, '200'), (3, '300')],
cur.fetchall(), "Multi insert test failed")
data = [{'id': 2}, {'id': 3}]
stmt = "DELETE FROM {0} WHERE col1 = %(id)s".format(tbl)
cur.executemany(stmt, data)
self.assertEqual(2, cur.rowcount)
stmt = "TRUNCATE TABLE {0}".format(tbl)
cur.execute(stmt)
stmt = (
"/*comment*/INSERT/*comment*/INTO/*comment*/{0}(col1,col2)VALUES"
"/*comment*/(%s,%s/*comment*/)/*comment()*/ON DUPLICATE KEY UPDATE"
" col1 = VALUES(col1)"
).format(tbl)
cur.executemany(stmt, [(4, 100), (5, 200), (6, 300)])
self.assertEqual(3, cur.rowcount)
cur.execute(stmt_select)
self.assertEqual([(4, '100'), (5, '200'), (6, '300')],
cur.fetchall(), "Multi insert test failed")
stmt = "TRUNCATE TABLE {0}".format(tbl)
cur.execute(stmt)
stmt = (
"INSERT INTO/*comment*/{0}(col1,col2)VALUES"
"/*comment*/(%s,'/*100*/')/*comment()*/ON DUPLICATE KEY UPDATE "
"col1 = VALUES(col1)"
).format(tbl)
cur.executemany(stmt, [(4,), (5,)])
self.assertEqual(2, cur.rowcount)
cur.execute(stmt_select)
self.assertEqual([(4, '/*100*/'), (5, '/*100*/')],
cur.fetchall(), "Multi insert test failed")
cur.close()
def _test_callproc_setup(self, cnx):
self._test_callproc_cleanup(cnx)
stmt_create1 = (
"CREATE PROCEDURE myconnpy_sp_1 "
"(IN pFac1 INT, IN pFac2 INT, OUT pProd INT) "
"BEGIN SET pProd := pFac1 * pFac2; END;")
stmt_create2 = (
"CREATE PROCEDURE myconnpy_sp_2 "
"(IN pFac1 INT, IN pFac2 INT, OUT pProd INT) "
"BEGIN SELECT 'abc'; SELECT 'def'; SET pProd := pFac1 * pFac2; "
"END;")
stmt_create3 = (
"CREATE PROCEDURE myconnpy_sp_3"
"(IN pStr1 VARCHAR(20), IN pStr2 VARCHAR(20), "
"OUT pConCat VARCHAR(100)) "
"BEGIN SET pConCat := CONCAT(pStr1, pStr2); END;")
stmt_create4 = (
"CREATE PROCEDURE myconnpy_sp_4"
"(IN pStr1 VARCHAR(20), INOUT pStr2 VARCHAR(20), "
"OUT pConCat VARCHAR(100)) "
"BEGIN SET pConCat := CONCAT(pStr1, pStr2); END;")
try:
cur = cnx.cursor()
cur.execute(stmt_create1)
cur.execute(stmt_create2)
cur.execute(stmt_create3)
cur.execute(stmt_create4)
except errors.Error as err:
self.fail("Failed setting up test stored routine; {0}".format(err))
cur.close()
def _test_callproc_cleanup(self, cnx):
sp_names = ('myconnpy_sp_1', 'myconnpy_sp_2', 'myconnpy_sp_3',
'myconnpy_sp_4')
stmt_drop = "DROP PROCEDURE IF EXISTS {procname}"
try:
cur = cnx.cursor()
for sp_name in sp_names:
cur.execute(stmt_drop.format(procname=sp_name))
except errors.Error as err:
self.fail(
"Failed cleaning up test stored routine; {0}".format(err))
cur.close()
def test_callproc(self):
cur = self._get_cursor(self.cnx)
self.check_method(cur, 'callproc')
self.assertRaises(ValueError, cur.callproc, None)
self.assertRaises(ValueError, cur.callproc, 'sp1', None)
config = tests.get_mysql_config()
self.cnx.get_warnings = True
self._test_callproc_setup(self.cnx)
cur = self.cnx.cursor()
if tests.MYSQL_VERSION < (5, 1):
exp = ('5', '4', b'20')
else:
exp = (5, 4, 20)
result = cur.callproc('myconnpy_sp_1', (exp[0], exp[1], 0))
self.assertEqual(exp, result)
if tests.MYSQL_VERSION < (5, 1):
exp = ('6', '5', b'30')
else:
exp = (6, 5, 30)
result = cur.callproc('myconnpy_sp_2', (exp[0], exp[1], 0))
self.assertTrue(isinstance(cur._stored_results, list))
self.assertEqual(exp, result)
exp_results = [
('abc',),
('def',)
]
for i, result in enumerate(cur.stored_results()):
self.assertEqual(exp_results[i], result.fetchone())
exp = ('ham', 'spam', 'hamspam')
result = cur.callproc('myconnpy_sp_3', (exp[0], exp[1], 0))
self.assertTrue(isinstance(cur._stored_results, list))
self.assertEqual(exp, result)
exp = ('ham', 'spam', 'hamspam')
result = cur.callproc('myconnpy_sp_4',
(exp[0], (exp[1], 'CHAR'), (0, 'CHAR')))
self.assertTrue(isinstance(cur._stored_results, list))
self.assertEqual(exp, result)
cur.close()
self._test_callproc_cleanup(self.cnx)
def test_fetchone(self):
cur = self._get_cursor(self.cnx)
self.assertEqual(None, cur.fetchone())
cur = self.cnx.cursor()
cur.execute("SELECT BINARY 'ham'")
exp = (b'ham',)
self.assertEqual(exp, cur.fetchone())
self.assertEqual(None, cur.fetchone())
cur.close()
def test_fetchmany(self):
"""MySQLCursor object fetchmany()-method"""
cur = self._get_cursor(self.cnx)
self.assertEqual([], cur.fetchmany())
tbl = 'myconnpy_fetch'
self.setup_table(self.cnx, tbl)
stmt_insert = (
"INSERT INTO {table} (col1,col2) "
"VALUES (%s,%s)".format(table=tbl))
stmt_select = (
"SELECT col1,col2 FROM {table} "
"ORDER BY col1 DESC".format(table=tbl))
cur = self.cnx.cursor()
nrrows = 10
data = [(i, str(i * 100)) for i in range(1, nrrows+1)]
cur.executemany(stmt_insert, data)
cur.execute(stmt_select)
exp = [(10, '1000'), (9, '900'), (8, '800'), (7, '700')]
rows = cur.fetchmany(4)
self.assertTrue(tests.cmp_result(exp, rows),
"Fetching first 4 rows test failed.")
exp = [(6, '600'), (5, '500'), (4, '400')]
rows = cur.fetchmany(3)
self.assertTrue(tests.cmp_result(exp, rows),
"Fetching next 3 rows test failed.")
exp = [(3, '300'), (2, '200'), (1, '100')]
rows = cur.fetchmany(3)
self.assertTrue(tests.cmp_result(exp, rows),
"Fetching next 3 rows test failed.")
self.assertEqual([], cur.fetchmany())
cur.close()
def test_fetchall(self):
cur = self._get_cursor(self.cnx)
self.assertRaises(errors.InterfaceError, cur.fetchall)
tbl = 'myconnpy_fetch'
self.setup_table(self.cnx, tbl)
stmt_insert = (
"INSERT INTO {table} (col1,col2) "
"VALUES (%s,%s)".format(table=tbl))
stmt_select = (
"SELECT col1,col2 FROM {table} "
"ORDER BY col1 ASC".format(table=tbl))
cur = self.cnx.cursor()
cur.execute("SELECT * FROM {table}".format(table=tbl))
self.assertEqual([], cur.fetchall(),
"fetchall() with empty result should return []")
nrrows = 10
data = [(i, str(i * 100)) for i in range(1, nrrows+1)]
cur.executemany(stmt_insert, data)
cur.execute(stmt_select)
self.assertTrue(tests.cmp_result(data, cur.fetchall()),
"Fetching all rows failed.")
self.assertEqual(None, cur.fetchone())
cur.close()
def test_raise_on_warning(self):
self.cnx.raise_on_warnings = True
cur = self._get_cursor(self.cnx)
cur.execute("SELECT 'a' + 'b'")
try:
cur.execute("SELECT 'a' + 'b'")
cur.fetchall()
except errors.DatabaseError:
pass
else:
self.fail("Did not get exception while raising warnings.")
def test__str__(self):
cur = self._get_cursor(self.cnx)
self.assertEqual("CMySQLCursor: (Nothing executed yet)",
cur.__str__())
cur.execute("SELECT VERSION()")
cur.fetchone()
self.assertEqual("CMySQLCursor: SELECT VERSION()",
cur.__str__())
stmt = "SELECT VERSION(),USER(),CURRENT_TIME(),NOW(),SHA1('myconnpy')"
cur.execute(stmt)
cur.fetchone()
self.assertEqual("CMySQLCursor: {0}..".format(stmt[:40]),
cur.__str__())
cur.close()
def test_column_names(self):
cur = self._get_cursor(self.cnx)
stmt = "SELECT NOW() as now, 'The time' as label, 123 FROM dual"
exp = (b'now', 'label', b'123')
cur.execute(stmt)
cur.fetchone()
self.assertEqual(exp, cur.column_names)
cur.close()
def test_statement(self):
cur = CMySQLCursor(self.cnx)
exp = 'SELECT * FROM ham'
cur._executed = exp
self.assertEqual(exp, cur.statement)
cur._executed = ' ' + exp + ' '
self.assertEqual(exp, cur.statement)
cur._executed = b'SELECT * FROM ham'
self.assertEqual(exp, cur.statement)
def test_with_rows(self):
cur = CMySQLCursor(self.cnx)
self.assertFalse(cur.with_rows)
cur._description = ('ham', 'spam')
self.assertTrue(cur.with_rows)
def tests_nextset(self):
cur = CMySQLCursor(self.cnx)
stmt = "SELECT 'result', 1; SELECT 'result', 2; SELECT 'result', 3"
cur.execute(stmt)
self.assertEqual([('result', 1)], cur.fetchall())
self.assertTrue(cur.nextset())
self.assertEqual([('result', 2)], cur.fetchall())
self.assertTrue(cur.nextset())
self.assertEqual([('result', 3)], cur.fetchall())
self.assertEqual(None, cur.nextset())
tbl = 'myconnpy_nextset'
stmt = "SELECT 'result', 1; INSERT INTO {0} () VALUES (); " \
"SELECT * FROM {0}".format(tbl)
self.setup_table(self.cnx, tbl)
cur.execute(stmt)
self.assertEqual([('result', 1)], cur.fetchall())
try:
cur.nextset()
except errors.Error as exc:
self.assertEqual(errorcode.CR_NO_RESULT_SET, exc.errno)
self.assertEqual(1, cur._affected_rows)
self.assertTrue(cur.nextset())
self.assertEqual([(1, None, 0)], cur.fetchall())
self.assertEqual(None, cur.nextset())
cur.close()
self.cnx.rollback()
def tests_execute_multi(self):
tbl = 'myconnpy_execute_multi'
stmt = "SELECT 'result', 1; INSERT INTO {0} () VALUES (); " \
"SELECT * FROM {0}".format(tbl)
self.setup_table(self.cnx, tbl)
multi_cur = CMySQLCursor(self.cnx)
results = []
exp = [
(u"SELECT 'result', 1", [(u'result', 1)]),
(u"INSERT INTO {0} () VALUES ()".format(tbl), 1, 1),
(u"SELECT * FROM {0}".format(tbl), [(1, None, 0)]),
]
for cur in multi_cur.execute(stmt, multi=True):
if cur.with_rows:
results.append((cur.statement, cur.fetchall()))
else:
results.append(
(cur.statement, cur._affected_rows, cur.lastrowid)
)
self.assertEqual(exp, results)
cur.close()
self.cnx.rollback()
class CExtMySQLCursorBufferedTests(tests.CMySQLCursorTests):
def _get_cursor(self, cnx=None):
if not cnx:
cnx = CMySQLConnection(**self.config)
self.cnx.buffered = True
return CMySQLCursorBuffered(connection=cnx)
def test___init__(self):
self.assertRaises(errors.InterfaceError, CMySQLCursorBuffered,
connection='ham')
cur = self._get_cursor(self.cnx)
self.assertTrue(hex(id(self.cnx)).upper()[2:]
in repr(cur._cnx).upper())
def test_execute(self):
self.cnx.get_warnings = True
cur = self._get_cursor(self.cnx)
self.assertEqual(None, cur.execute(None, None))
self.assertEqual(True,
isinstance(cur, CMySQLCursorBuffered))
cur.execute("SELECT 1")
self.assertEqual((1,), cur.fetchone())
def test_raise_on_warning(self):
self.cnx.raise_on_warnings = True
cur = self._get_cursor(self.cnx)
self.assertRaises(errors.DatabaseError,
cur.execute, "SELECT 'a' + 'b'")
def test_with_rows(self):
cur = self._get_cursor(self.cnx)
cur.execute("SELECT 1")
self.assertTrue(cur.with_rows)
class CMySQLCursorRawTests(tests.CMySQLCursorTests):
def _get_cursor(self, cnx=None):
if not cnx:
cnx = CMySQLConnection(**self.config)
return CMySQLCursorRaw(connection=cnx)
def test_fetchone(self):
cur = self._get_cursor(self.cnx)
self.assertEqual(None, cur.fetchone())
cur.execute("SELECT 1, 'string', MAKEDATE(2010,365), 2.5")
exp = (b'1', b'string', b'2010-12-31', b'2.5')
self.assertEqual(exp, cur.fetchone())

View File

@@ -1,10 +0,0 @@
[connector_python]
user=root
password=mypass
database=cpydata
port=10000
[connector_python]
user=mysql
password=mypass
database=duplicate_data

View File

@@ -1,13 +0,0 @@
[group1]
option1 = 5
option2 = 10
[group2]
option1 = 10
option2 = 20
[group3]
option3 = 100
[mysql]
user = ham

View File

@@ -1,13 +0,0 @@
[group1]
option1 = 15
option2 = 20
[group2]
option1 = 20
option2 = 30
[group4]
option3 = 200
[client]
user = spam

View File

@@ -1,29 +0,0 @@
[client]
password=12345
port=1000
socket=/var/run/mysqld/mysqld.sock
ssl-ca=dummyCA
ssl-cert=dummyCert
ssl-key=dummyKey
ssl-cipher=AES256-SHA:CAMELLIA256-SHA
[mysqld_safe]
socket=/var/run/mysqld/mysqld1.sock
nice=0
[mysqld]
user=mysql
pid-file=/var/run/mysqld/mysqld.pid
socket=/var/run/mysqld/mysqld2.sock
port=1001
basedir=/usr
datadir=/var/lib/mysql
tmpdir=/tmp
lc-messages-dir=/usr/share/mysql
skip-external-locking
bind-address=127.0.0.1
log_error=/var/log/mysql/error.log

View File

@@ -1,14 +0,0 @@
[pooling]
pool_size = 1
pool_name = my_pool
pool_reset_session = True
[fabric]
fabric_host = fabric.example.com
fabric_connect_delay = 3
fabric_ssl_ca = /path/to/ssl
fabric_password = foo
[failover]
failover = ({'port': 3306, 'pool_name': 'failA' }, {'port': 3307, 'pool_name': 'failB' },)

View File

@@ -1,20 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIDVzCCAj+gAwIBAgIJAIUsZ/vX9kOGMA0GCSqGSIb3DQEBBQUAMEIxJTAjBgNV
BAsMHE15U1FMQ29ubmVjdG9yUHl0aG9uIFJvb3QgQ0ExGTAXBgNVBAMMEE15Q29u
blB5IFJvb3QgQ0EwHhcNMTMwMzI2MTUzNTUyWhcNMjIwNDE0MTUzNTUyWjBCMSUw
IwYDVQQLDBxNeVNRTENvbm5lY3RvclB5dGhvbiBSb290IENBMRkwFwYDVQQDDBBN
eUNvbm5QeSBSb290IENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA
qWcX9kD+b8c3hkPtPlIgwTsfGvhm/bJ64RHjCtQc2pi/fv9hlcryor8tWmdCCcw7
ajg5n/QAIJ8crD5D0kheGEnWVI7dyVxZVfT3CiKuS+GBxuQP2ejJi4aDGh2McVv4
aq1dXRqf2YWkM8PUjM0lzUD9MC9S4APtP6ux0TBhz5rv2ZWdg2EAjAl7Q56KM5m6
odpF+Z1ExnfVpNzWnpvlYHJ+GhbVWb2F0NbqBTmz4OLEAxU/O2fo43dwVlHp+yNd
ib2V+VxeeyZmTt1CIeK6DStAiKdNLN5/N/+2FHZ9/XcA6qqxLFLeuTIySlPmuaX6
u2C8tmOWp99TCUL+GZ2iBwIDAQABo1AwTjAdBgNVHQ4EFgQU1objOGh5rgtBTmjK
gPkN6SgXl64wHwYDVR0jBBgwFoAU1objOGh5rgtBTmjKgPkN6SgXl64wDAYDVR0T
BAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEAWgHZzUo8oGP7YxMn9YACdbipTRYU
IzGF+Cf0ueXktcEDbq7AIa6MsxXTp8pFOObvLiiecrMngYlqfHlYPL2HG+zOLDig
nmkO4pGwTqCDZHO4aYBdiVMlaxSpxMX9R/kFYRP1P4AGLOp66FirNO5iLNlTIjpf
PGebF+k0B1zUSUPsrZfa/d29XcJxBaw7aEOhARQYsymItasnTdcKvjZp1ahGnZYz
yCDtJjVbXK/4qEtiSA4qcV1HrNuHmhZEwWahntLqo++x3oLK7DrWfHwTX5gHMyv2
DGTggnNfB8uzzNe3giT0j6ie9DJEnvv1hB0GpUToUNECusrKsYnWLdJkIA==
-----END CERTIFICATE-----

View File

@@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAqWcX9kD+b8c3hkPtPlIgwTsfGvhm/bJ64RHjCtQc2pi/fv9h
lcryor8tWmdCCcw7ajg5n/QAIJ8crD5D0kheGEnWVI7dyVxZVfT3CiKuS+GBxuQP
2ejJi4aDGh2McVv4aq1dXRqf2YWkM8PUjM0lzUD9MC9S4APtP6ux0TBhz5rv2ZWd
g2EAjAl7Q56KM5m6odpF+Z1ExnfVpNzWnpvlYHJ+GhbVWb2F0NbqBTmz4OLEAxU/
O2fo43dwVlHp+yNdib2V+VxeeyZmTt1CIeK6DStAiKdNLN5/N/+2FHZ9/XcA6qqx
LFLeuTIySlPmuaX6u2C8tmOWp99TCUL+GZ2iBwIDAQABAoIBAAKXtFMtfXdieiQQ
6BGbGis652f3Q0RAtga5ylrBEkv6KHweFnU/bOU2vc/zYpxZxtMCV0duaY4WQU8V
iN4wA1il0KTsptJNGoTpQdqi2z4IDn9nwCJaoLME9P6yUxLtEGk5jAM/xBCFLhUo
uxkIjrqMcxOIteD9zmS6EPedoPGXbBFK2jBheArszZ/fiNhi7D2w03/s/Dhu14Px
5gjG5f+A/lS0l81RC5aeUt+wghA5y7TxY20fN1QU+XX2+Oft/HBq6xNloMlmPhzN
loi952HlWLZS31QJRgEhXZ3aJMHDQ3z9I4M6RfdngW2aJTbuJq/weFgN0Z8ogDLK
k/kuTfECgYEA2F5uRlUEW/0MKPrd10q5Ln3i3o0dQmW/QaPZ+SCjGan7xiC8Hm/2
awkZIIaHQThkgRtxgsLOY+o7NVWkzTeLBlCKl12O0TQ3ZofwXdWPdI2b7adiFnEd
6/htxQd90En7BgNls39j9bK7UVDDilJrRDKvyNzQKwHP95QRxJellJkCgYEAyG5p
lB9j78CLWL9lZZgG7Xu/+DR63gceLLBAYclIfHzIb5B49TakasEgsT6JKbqmmwcC
VXs+SSw0b1dYaFajOL9ceMkOFEn9KV5bESKcPJ2/JxBW6e5j6i4eo+oQxTTiAn75
UEcmPx8aBCtxhj4LFPKSwzi8mJNliRH2lLAYb58CgYEAlRrGLauq3GWOurLeq92v
ra1M6YcfkcEiQu7SaI8oNqhgfBHU8bjAfNSBP1vV24ksIZiy6aSrrEkfUkrZzh4n
rUtVpqfvopW0U/D8IP3p5S0tNmIyAzsinpnNs4jNF/vThDpVHJR+YzQvSAM7LZhM
mWvAndAlmG2gToH4mJzUm4kCgYBKFk4ee4/0Uobvsifn6s88v46RT8zO/3CO8kOK
Id4Sbgmk+5FKiv0xnNvZyJTpAN6O1YNuV5UJdTaYpX+/aa8BzfJ/j0oOA995iDA/
YDzCR0keRnLqG72BFbUrv9ydGNQmOgssOnCPyo5SVkCrb4mnH5dSZEmKWImipiow
gfs2XwKBgQDSjbMlJme1fwNEt7EvwLJ6Zd4wSLs70IWvcX3k0g4PMhSj9J1zXRP+
wpOZCa4GW2y21t5dpHG2B+a9Sd+z0/NMSSBZ8SUfrbZza3gC6cJyPoBYy7w/PFx3
CgHcWRVI3n6+dkMYzpu2J1zzB2y0aiBE4icDq5+Uq7kO2OIytPVnHA==
-----END RSA PRIVATE KEY-----

View File

@@ -1,18 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIC9TCCAd0CAQEwDQYJKoZIhvcNAQEFBQAwQjElMCMGA1UECwwcTXlTUUxDb25u
ZWN0b3JQeXRob24gUm9vdCBDQTEZMBcGA1UEAwwQTXlDb25uUHkgUm9vdCBDQTAe
Fw0xMzAzMjYxNTM1NTJaFw0yMjA0MTQxNTM1NTJaMD8xKTAnBgNVBAsMIE15U1FM
Q29ubmVjdG9yUHl0aG9uIENsaWVudCBDZXJ0MRIwEAYDVQQDDAlsb2NhbGhvc3Qw
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDXbL7sr/k/W4LwwzTKJj5i
1QtcZL0tMyBhAwuI7XQVyJBVvY7dRUM+G30ADOcUw5DscYbkkVu3L2NtsnmuyB8o
0Y5bbHpTv4xTrVfsQuDkMLe+/LwFfL7XrY1Bm13xdEn345b6edfvhre7eatCgIaG
IKfFr5JDv5oN4faGEJpqYahE/WdxM7zv6xb7Wx+yqLlezldU34VcLcghi8zfDkxb
Fb4cZSgko/9RT7lTUGBJSSgITnq3Re0qANah7UbqFkTM2wfltoXGerbWMYuzOfQo
5r0FiScjuvACkDALHAdUbX4UbXasArqpGovyVqHp4OWu3FWRfcCUnxAxfj3G3x79
AgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAFi+U6Fyc1L0qCTCiMvUMQuXacnOMH4q
rHm7qDKkHHcMMGsspNXvLcVKEwJrX3dhP3dZ52eKyFsOjuTkO9eU5H8V2alO8iGD
Zb6vHT/pQRInoc39SVDFx1QnJ7RlC2Z99xzncHMQChSlDCC+Lft/K5am7vXFwQ3e
icfLqmR5hz6nc+opnPc7WbQu/cc7PesP5uroyKScYoqAiDJ2cKQJQFPM4Cvt/KZ3
22H/yCyQNkplIcrlQRF+l+sInNlJZr36INF0o91GcucyuLQzOXUn0L5eAyFzA9RQ
8xkVztqRN++CgbGAhqIt8ERBtxBvCpNxuFpgm4dPKCTLm+r7fJcKwDI=
-----END CERTIFICATE-----

View File

@@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA12y+7K/5P1uC8MM0yiY+YtULXGS9LTMgYQMLiO10FciQVb2O
3UVDPht9AAznFMOQ7HGG5JFbty9jbbJ5rsgfKNGOW2x6U7+MU61X7ELg5DC3vvy8
BXy+162NQZtd8XRJ9+OW+nnX74a3u3mrQoCGhiCnxa+SQ7+aDeH2hhCaamGoRP1n
cTO87+sW+1sfsqi5Xs5XVN+FXC3IIYvM3w5MWxW+HGUoJKP/UU+5U1BgSUkoCE56
t0XtKgDWoe1G6hZEzNsH5baFxnq21jGLszn0KOa9BYknI7rwApAwCxwHVG1+FG12
rAK6qRqL8lah6eDlrtxVkX3AlJ8QMX49xt8e/QIDAQABAoIBAQCjSd5+cfSvvaHG
9XAyOkLXjz0JT6LFfBdy8Wfw5mwzhs9A7mo39qQ9k4BwZVdTOdnEH1lsL3IhrF3l
bH8nqLFVs2IAkn02td6cHqyifR8SWIsuzUuHrULLINYNgML4nnji2TQ7r9epy6fB
Bzx1MA7H5EDHa4mmqLkRBNJkVHl3YCGM25tXyhixC5MsNdSpTwLMvv/RVLqsHtH6
WZ3P8VZi/iOk28TQwLcFTQz4g6RM3jO/1O9tXhob9g1iUoLNd3mLR3+sdkhHf5bU
ttEzxvfVl4Fe0463J4I/JeofGtDBkWgR4UI5ZVfC0xLvmVA4J3cxgUeAKsIwuqQT
9Gi4MDOBAoGBAP6MGCwZUmVqoaqaNF/XckwieJctYLUxhf/KA9S3pq2Y4PPFb7FO
srqn90c2Qb4o13iZzak9rPKUVKwcL+VYknrVGb1ALyWySI7WEaUzsXLIGF2w010l
TNUyL82NynGUx3/4gxvJf/K9weVkTU7KK2tfdB+ridv1ZcSn9bETMvVJAoGBANin
fdqLh8tFMqTsc+bMvlogzns9y+MluJeqz+On706sVR6XsEF8LtzcnHAwOYFef6h5
cgrKGzfWaz88tNdgB82p/smLQcz4ouFAzTBX3y/+LG/+ybbkR9a2sO+gHA1eAukB
Ia5q/t5jI0XiTa4lVoj2IJK7/hBjIYYBLA2TKQAVAoGBAPP6k7CxFKjga9R5uXmj
p4oSAEPm2qrRrP5fQwzAeqIpxnPg6g2owObn17wJ5Tm/K8gMo3N0CjD4u6+71Kyf
GMdjOiiLPKWFHMbLqF4QDiVWZQRoWC8PcXVnhSogncoAMLgYGpKnsFuaRh745KCA
Zt2jwEoawShzLfgwhO4U2OMBAoGAULfuctsjZ79LRBj4gZfsn6WzaEU4zlNCd/di
5t2tkjEwsWowd+VtjEoBWucMtb9gboN40r5D78TKRlA2zDtyDNT2IV7p0BUeki/T
gtxqQfY/1iYmPybEASIlv9F2QiCxkuAiDVq9xFtJTAMpj+VHXVXeAu1Zlf9pAQU0
xYX7c5UCgYA8Iux1dO7bakTlqwFUQCMM5IlzJJVT90Z8JQOCFk6a7bzTdOkyxYg2
BxiGjiFhNer6UshTNZj2svdUvVh9yH/iRGEP6eQAZR1AXIr1YazNmaG7tjIEZ4Yw
zx8gdGTIDYBDChFQmJIB9Y7iNF8bu8JmyVuo2SJHhIVyXN/cM9T6gg==
-----END RSA PRIVATE KEY-----

View File

@@ -1,18 +0,0 @@
-----BEGIN CERTIFICATE-----
MIIC9TCCAd0CAQEwDQYJKoZIhvcNAQEFBQAwQjElMCMGA1UECwwcTXlTUUxDb25u
ZWN0b3JQeXRob24gUm9vdCBDQTEZMBcGA1UEAwwQTXlDb25uUHkgUm9vdCBDQTAe
Fw0xMzAzMjYxNTM1NTJaFw0yMjA0MTQxNTM1NTJaMD8xKTAnBgNVBAsMIE15U1FM
Q29ubmVjdG9yUHl0aG9uIFNlcnZlciBDZXJ0MRIwEAYDVQQDDAlsb2NhbGhvc3Qw
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDv6WQ/Ssum4RchjXSlwbcB
au3WNccfodThBOAM27AOnJQWIjG4e5s9H7lLznI+VF5MgUbgbp/yz4D+CrSFvLgU
4xxzd1/SVbnzRJ5iD2EmaZPjoMkBmvDRd4ow6IdFN80Fpwxij6fUBHdRkyXyiYsG
FE94PQCyD1R47LSubd/gfcjXw8Bt5cWqcopiolZ01bYuMzeZIw0et9gf6Iih2Zh1
bs9RthHfL3BfN4knljF3XmRQhfsc4w3MvdulX4mcfzS+E+keOOgPjfjo9KVCD1Zl
F00wQdbSCWzf9uCP4OpKJGURyMQEmGMFPBOP98kqns1CqaE0PxKOpbcTX86nSEO5
AgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAFy4ONx0zFYgVNL046lfRmimyRf1gbmB
pyyug9eW6QuuTfqbzFWOYZY8pG2lzKnHNUMmgzMNMpiRLRJ38Dj5rApg+7OkiTT+
l4DMIR/YblJryEvx6tNUq2Cu9GXKW2qrGJO3XVniuBpmg1srugdwyxS+LdFofgBc
I4cKIDuXYATUpOFhEsFbMY6tGVeOXQN2jSWtUj6+mKiUWMyr+5NYD8xhjDV7q4GH
JfQqWFzw7prtSYzwB8lc0PM2SLwxeE9cQUYN/UkW8HRxM7Ft5KyyXUk+2Jg61sZ2
QxMCV6NAGYMX40WRDqIZbs9AbHWoCxEwoXWtcmNb0GInsk39lFMJqw4=
-----END CERTIFICATE-----

View File

@@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA7+lkP0rLpuEXIY10pcG3AWrt1jXHH6HU4QTgDNuwDpyUFiIx
uHubPR+5S85yPlReTIFG4G6f8s+A/gq0hby4FOMcc3df0lW580SeYg9hJmmT46DJ
AZrw0XeKMOiHRTfNBacMYo+n1AR3UZMl8omLBhRPeD0Asg9UeOy0rm3f4H3I18PA
beXFqnKKYqJWdNW2LjM3mSMNHrfYH+iIodmYdW7PUbYR3y9wXzeJJ5Yxd15kUIX7
HOMNzL3bpV+JnH80vhPpHjjoD4346PSlQg9WZRdNMEHW0gls3/bgj+DqSiRlEcjE
BJhjBTwTj/fJKp7NQqmhND8SjqW3E1/Op0hDuQIDAQABAoIBAQCyfCuVntq2E532
21td+ilhh6DcDfRPh0FuCwd46XQo2rqdYOEmw+bxaYmcaUG7N19UgZUuYX7j0RbB
aUt2d7ln6LMBAF2siRSndHR0tcZsIn3hCnygkhn5bHrF+iixCVuhie7/4KpWZOA0
M0o3D7b7Vd7tsEy1LAyHTmr5nkrBosIpLXQvnjj8kF6MOQW09/72l7eiFwnRQ3yW
eUn8l+vkIRpYzI/l1MFnj1lcGeDKRDFJMXZV7OropJaQabWuGyaddizP8ihhU/Vf
VEHFJnW+AS3JpMO2Bf8ICMGu+0d4AJsNPW7KNNlqv79Nws2ijl6bcWz+E7NAG55C
DY1LU5iBAoGBAPjf0QRpdDLd9+ntAkJMfSwhl0yqarZPuaGsKWnG5C7BPcj3wLaP
GHn3CI0SF0JiwN0zOrLv821im5Wr5Ux/OoSDdIR/y9Vp8joTno0+7MUU5zuN93r+
8EAHY5GEZoJ0ndU7xP50jEYq0AZinginyqtGyL6HpJL3VJoL14cCYYuRAoGBAPbH
4bHPWSEJY3X8Hq4KRbtyyTfT1s7zFrvDZHkWFH+tVD+DsKpmRQ5A0lWVBPhPaS1Y
GJcu9h9VKSEjBgM2ZJpB8A4zJGYIgsPXQTOQm/s9fbWj76zJ8r2z4W7P2Ry9U1e5
cwZnQgLoPvBL7IHm4J92RfoRZO5IohRyUDaAdpGpAoGAIL3hU8FD5kVJjl7+Axbp
CNtKem2ZKG8IrvplYGMoNfZ6WGwv0FS3FaSoXVbZ9IPld7R7rnre/a8RZPl+azf5
zOE2fRALEwKjOXzHSTHUGIGNgkpFGstbdDEEqmpOyi7pbNo2KnvO0JRlVdG3lM/u
W+YuFtLllegwGywfqMVpa+ECgYEAp4/StFv4xdDNIuh8oGnDLWLkM674FO7DydwD
FaCjbInxQWsWgq0MSIBFEO0tQbkRzkMZ91VgsqetVJ2mUHoXVxJcgBfDqDAxMe6v
i+atsqru922HqMg6tQo1kHs6jSQUOeVmr7te/ABb8+dpgE6WyE+Tdhdnc9AHlWCF
DGyvlXkCgYB2OYDiXSne2DYglcEk2pyr6h5sQRuKuYXnq7NWFTYIiLb/Bz6g9oLs
fV5LkBfCWRSg3PoR8hX3F8PC1i2G+50gXucoFdvlvS5bawPABxtYGqhyz63awNud
JnJIdqY3vLoUWeEZF3HmdBMN8jy6Am7pMynHFvoEjMBRmGNOjedZrA==
-----END RSA PRIVATE KEY-----

View File

@@ -1,80 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import mysql.connector
from tests import foreach_cnx, cnx_config
import tests
class Bug21449207(tests.MySQLConnectorTests):
def setUp(self):
self.tbl = 'Bug21449207'
cnx = mysql.connector.connect(**tests.get_mysql_config())
cnx.cmd_query("DROP TABLE IF EXISTS %s" % self.tbl)
create_table = (
"CREATE TABLE {0} ("
"id INT PRIMARY KEY, "
"a LONGTEXT "
") ENGINE=Innodb DEFAULT CHARSET utf8".format(self.tbl))
cnx.cmd_query(create_table)
cnx.close()
def tearDown(self):
cnx = mysql.connector.connect(**tests.get_mysql_config())
cnx.cmd_query("DROP TABLE IF EXISTS %s" % self.tbl)
cnx.close()
@foreach_cnx()
def test_uncompressed(self):
cur = self.cnx.cursor()
exp = 'a' * 15 + 'TheEnd'
insert = "INSERT INTO {0} (a) VALUES ('{1}')".format(self.tbl, exp)
cur.execute(insert)
cur.execute("SELECT a FROM {0}".format(self.tbl))
row = cur.fetchone()
self.assertEqual(exp, row[0])
self.assertEqual(row[0][-20:], exp[-20:])
@foreach_cnx()
def test_50k_compressed(self):
cur = self.cnx.cursor()
exp = 'a' * 50000 + 'TheEnd'
insert = "INSERT INTO {0} (a) VALUES ('{1}')".format(self.tbl, exp)
cur.execute(insert)
cur.execute("SELECT a FROM {0}".format(self.tbl))
row = cur.fetchone()
self.assertEqual(exp, row[0])
self.assertEqual(row[0][-20:], exp[-20:])
@foreach_cnx()
def test_16M_compressed(self):
cur = self.cnx.cursor()
exp = 'a' * 16777210 + 'TheEnd'
insert = "INSERT INTO {0} (a) VALUES ('{1}')".format(self.tbl, exp)
cur.execute(insert)
cur.execute("SELECT a FROM {0}".format(self.tbl))
row = cur.fetchone()
self.assertEqual(exp, row[0])
self.assertEqual(row[0][-20:], exp[-20:])

View File

@@ -1,59 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import os.path
import unittest
import mysql.connector
from tests import foreach_cnx, cnx_config
import tests
DATA_FILE = os.path.join('tests', 'data', 'random_big_bin.csv')
class Bug21449996(tests.MySQLConnectorTests):
def setUp(self):
self.table_name = 'Bug21449996'
cnx = mysql.connector.connect(**tests.get_mysql_config())
cnx.cmd_query("DROP TABLE IF EXISTS %s" % self.table_name)
cnx.cmd_query("CREATE TABLE %s (c1 BLOB)" % self.table_name)
cnx.close()
def tearDown(self):
cnx = mysql.connector.connect(**tests.get_mysql_config())
cnx.cmd_query("DROP TABLE IF EXISTS %s" % self.table_name)
cnx.close()
@foreach_cnx()
def test_load_data_compressed(self):
try:
cur = self.cnx.cursor()
sql = "LOAD DATA LOCAL INFILE '%s' INTO TABLE %s" % (
DATA_FILE, self.table_name)
cur.execute(sql)
except mysql.connector.errors.InterfaceError as exc:
self.fail(exc)
cur.execute("SELECT COUNT(*) FROM %s" % self.table_name)
self.assertEqual(11486, cur.fetchone()[0])

View File

@@ -1,76 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
""" BUG21879859
"""
import os.path
import unittest
import mysql.connector
from mysql.connector import Error
from tests import foreach_cnx, cnx_config
import tests
try:
from mysql.connector.connection_cext import CMySQLConnection
except ImportError:
# Test without C Extension
CMySQLConnection = None
class Bug21879859(tests.MySQLConnectorTests):
def setUp(self):
self.table = "Bug21879859"
self.proc = "Bug21879859_proc"
cnx = mysql.connector.connect(**tests.get_mysql_config())
cur = cnx.cursor()
cur.execute("DROP TABLE IF EXISTS {0}".format(self.table))
cur.execute("DROP PROCEDURE IF EXISTS {0}".format(self.proc))
cur.execute("CREATE TABLE {0} (c1 VARCHAR(1024))".format(self.table))
cur.execute(
"CREATE PROCEDURE {1}() BEGIN SELECT 1234; "
"SELECT t from {0}; SELECT '' from {0}; END".format(
self.table, self.proc
));
def tearDown(self):
cnx = mysql.connector.connect(**tests.get_mysql_config())
cur = cnx.cursor()
cur.execute("DROP TABLE IF EXISTS {0}".format(self.table))
cur.execute("DROP PROCEDURE IF EXISTS {0}".format(self.proc))
@cnx_config(consume_results=True)
@foreach_cnx()
def test_consume_after_callproc(self):
cur = self.cnx.cursor()
cur.execute("INSERT INTO {0} VALUES ('a'),('b'),('c')".format(self.table))
# expected to fail
self.assertRaises(Error, cur.callproc, self.proc)
try:
cur.close()
except mysql.connector.Error as exc:
self.fail("Failed closing: " + str(exc))

View File

@@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
""" BUG21879914 Fix using C/Extension with only CA given
"""
import os.path
import unittest
import mysql.connector
from tests import foreach_cnx, cnx_config
import tests
try:
from mysql.connector.connection_cext import CMySQLConnection
except ImportError:
# Test without C Extension
CMySQLConnection = None
TEST_SSL = {
'ca': os.path.join(tests.SSL_DIR, 'tests_CA_cert.pem'),
'cert': os.path.join(tests.SSL_DIR, 'tests_client_cert.pem'),
'key': os.path.join(tests.SSL_DIR, 'tests_client_key.pem'),
}
OPTION_FILE = os.path.join('tests', 'data', 'option_files', 'my.cnf')
class Bug21879914(tests.MySQLConnectorTests):
def test_ssl_cipher_in_option_file(self):
config = tests.get_mysql_config()
config['ssl_ca'] = TEST_SSL['ca']
config['use_pure'] = False
cnx = mysql.connector.connect(**config)
cnx.cmd_query("SHOW STATUS LIKE 'Ssl_cipher'")
self.assertNotEqual(cnx.get_row()[1], '') # Ssl_cipher must have a value

View File

@@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
""" BUG22545879 Fix reading and using ssl-cipher MySQL option
"""
import os.path
import unittest
import mysql.connector
from tests import foreach_cnx, cnx_config
import tests
try:
from mysql.connector.connection_cext import CMySQLConnection
except ImportError:
# Test without C Extension
CMySQLConnection = None
TEST_SSL = {
'ca': os.path.join(tests.SSL_DIR, 'tests_CA_cert.pem'),
'cert': os.path.join(tests.SSL_DIR, 'tests_client_cert.pem'),
'key': os.path.join(tests.SSL_DIR, 'tests_client_key.pem'),
}
OPTION_FILE = os.path.join('tests', 'data', 'option_files', 'my.cnf')
class Bug21449996(tests.MySQLConnectorTests):
@cnx_config(ssl_ca=TEST_SSL['ca'], ssl_cert=TEST_SSL['cert'], ssl_key=TEST_SSL['key'],
ssl_cipher="AES256-SHA:CAMELLIA256-SHA")
@foreach_cnx()
def test_cnx_argument_ssl_cipher(self):
self.assertIn('cipher', self.cnx._ssl)
self.assertEqual("AES256-SHA:CAMELLIA256-SHA", self.cnx._ssl['cipher'])
def test_ssl_cipher_in_option_file(self):
config = tests.get_mysql_config()
config['option_files'] = [OPTION_FILE]
cnx = mysql.connector.MySQLConnection()
cnx.config(**config)
self.assertIn('cipher', cnx._ssl)
self.assertEqual("AES256-SHA:CAMELLIA256-SHA", cnx._ssl['cipher'])
if CMySQLConnection:
cnx = CMySQLConnection()
cnx.config(**config)
self.assertIn('cipher', cnx._ssl)
self.assertEqual("AES256-SHA:CAMELLIA256-SHA", cnx._ssl['cipher'])

View File

@@ -1,746 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2015, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Module for managing and running a MySQL server"""
import sys
import os
import signal
import re
from shutil import rmtree
import subprocess
import logging
import time
import ctypes
import socket
import errno
import struct
try:
from ctypes import wintypes
except (ImportError, ValueError):
# We are not on Windows
pass
try:
from socketserver import (
ThreadingMixIn, TCPServer, BaseRequestHandler
)
except ImportError:
from SocketServer import (
ThreadingMixIn, TCPServer, BaseRequestHandler
)
TCPServer.allow_reuse_address = True
import tests
LOGGER = logging.getLogger(tests.LOGGER_NAME)
DEVNULL = open(os.devnull, 'w')
# MySQL Server executable name
if os.name == 'nt':
EXEC_MYSQLD = 'mysqld.exe'
else:
EXEC_MYSQLD = 'mysqld'
# MySQL client executable name
if os.name == 'nt':
EXEC_MYSQL = 'mysql.exe'
else:
EXEC_MYSQL = 'mysql'
def _convert_forward_slash(path):
"""Convert forward slashes with backslashes
This function replaces forward slashes with backslashes. This
is necessary using Microsoft Windows for location of files in
the option files.
Returns a string
"""
if os.name == 'nt':
nmpath = os.path.normpath(path)
return nmpath.replace('\\', '\\\\')
return path
def process_running(pid):
"""Check whether a process is running
This function takes the process ID or pid and checks whether it is
running. It works for Windows and UNIX-like systems.
Return True or False
"""
if os.name == 'nt':
# We are on Windows
process = subprocess.Popen(['tasklist'], stdout=subprocess.PIPE)
output, _ = process.communicate()
lines = [line.split(None, 2) for line in output.splitlines() if line]
for name, apid, _ in lines:
name = name.decode('utf-8')
if name == EXEC_MYSQLD and pid == int(apid):
return True
return False
# We are on a UNIX-like system
try:
os.kill(pid, 0)
except OSError:
return False
return True
def process_terminate(pid):
"""Terminates a process
This function terminates a running process using it's pid (process
ID), sending a SIGKILL on Posix systems and using ctypes.windll
on Windows.
Raises MySQLServerError on errors.
"""
if os.name == 'nt':
winkernel = ctypes.windll.kernel32
process = winkernel.OpenProcess(0x0001, 0, pid) # PROCESS_TERMINATE
winkernel.TerminateProcess(process, 1)
winkernel.CloseHandle(process)
else:
os.kill(pid, signal.SIGTERM)
def get_pid(pid_file):
"""Returns the PID read from the PID file
Returns None or int.
"""
try:
return int(open(pid_file, 'r').readline().strip())
except IOError as err:
LOGGER.debug("Failed reading pid file: %s", err)
return None
class MySQLServerError(Exception):
"""Exception for raising errors when managing a MySQL server"""
pass
class MySQLBootstrapError(MySQLServerError):
"""Exception for raising errors around bootstrapping a MySQL server"""
pass
class MySQLServerBase(object):
"""Base for classes managing a MySQL server"""
def __init__(self, basedir, option_file=None, sharedir=None):
self._basedir = basedir
self._sbindir = None
self._sharedir = sharedir
self._scriptdir = None
self._process = None
self._lc_messages_dir = None
self._init_mysql_install()
self._version = self._get_version()
if option_file and os.access(option_file, 0):
MySQLBootstrapError("Option file not accessible: {name}".format(
name=option_file))
self._option_file = option_file
def _init_mysql_install(self):
"""Checking MySQL installation
Check the MySQL installation and set the directories where
to find binaries and SQL bootstrap scripts.
Raises MySQLBootstrapError when something fails.
"""
# Locate mysqld, mysql binaries
LOGGER.info("Locating mysql binaries (could take a while)")
files_to_find = [EXEC_MYSQL, EXEC_MYSQLD]
for root, dirs, files in os.walk(self._basedir):
if self._sbindir:
break
for afile in files:
if (afile == EXEC_MYSQLD and
os.access(os.path.join(root, afile), 0)):
self._sbindir = root
files_to_find.remove(EXEC_MYSQLD)
elif (afile == EXEC_MYSQL and
os.access(os.path.join(root, afile), 0)):
self._bindir = root
files_to_find.remove(EXEC_MYSQL)
if not files_to_find:
break
if not self._sbindir:
raise MySQLBootstrapError(
"MySQL binaries not found under {0}".format(self._basedir))
# Try to locate errmsg.sys and mysql_system_tables.sql
if not self._sharedir:
match = self._get_mysqld_help_info(r'^lc-messages-dir\s+(.*)\s*$')
if match:
self._sharedir = match[0]
if not self._sharedir:
raise MySQLBootstrapError("Failed getting share folder. "
"Use --with-mysql-share.")
LOGGER.debug("Using share folder: %s", self._sharedir)
found = False
for root, dirs, files in os.walk(self._sharedir):
if found:
break
for afile in files:
if afile == 'errmsg.sys' and 'english' in root:
self._lc_messages_dir = os.path.abspath(
os.path.join(root, os.pardir)
)
elif afile == 'mysql_system_tables.sql':
self._scriptdir = root
if not self._lc_messages_dir or not self._scriptdir:
raise MySQLBootstrapError(
"errmsg.sys and mysql_system_tables.sql not found"
" under {0}".format(self._sharedir))
LOGGER.debug("Location of MySQL Server binaries: %s", self._sbindir)
LOGGER.debug("Error messages: %s", self._lc_messages_dir)
LOGGER.debug("SQL Script folder: %s", self._scriptdir)
def _get_cmd(self):
"""Returns command to start MySQL server
Returns list.
"""
cmd = [
os.path.join(self._sbindir, EXEC_MYSQLD),
"--defaults-file={0}".format(self._option_file),
]
if os.name == 'nt':
cmd.append('--standalone')
return cmd
def _get_mysqld_help_info(self, needle):
"""Get information from the mysqld binary help
This is basically a grep. Needle is a regular expression which
will be looked for in each line of the mysqld --help --verbose
output. We return the first match as a list.
"""
cmd = [
os.path.join(self._sbindir, EXEC_MYSQLD),
'--help', '--verbose'
]
prc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=DEVNULL)
help_verbose = prc.communicate()[0]
regex = re.compile(needle)
for help_line in help_verbose.splitlines():
help_line = help_line.decode('utf-8').strip()
match = regex.search(help_line)
if match:
return match.groups()
return []
def _get_version(self):
"""Get the MySQL server version
This method executes mysqld with the --version argument. It parses
the output looking for the version number and returns it as a
tuple with integer values: (major,minor,patch)
Returns a tuple.
"""
cmd = [
os.path.join(self._sbindir, EXEC_MYSQLD),
'--version'
]
prc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=DEVNULL)
verstr = str(prc.communicate()[0])
matches = re.match(r'.*Ver (\d)\.(\d).(\d{1,2}).*', verstr)
if matches:
return tuple([int(v) for v in matches.groups()])
else:
raise MySQLServerError(
'Failed reading version from mysqld --version')
@property
def version(self):
"""Returns the MySQL server version
Returns a tuple.
"""
return self._version
def _start_server(self):
"""Start the MySQL server"""
try:
cmd = self._get_cmd()
self._process = subprocess.Popen(cmd, stdout=DEVNULL,
stderr=DEVNULL)
except (OSError, ValueError) as err:
raise MySQLServerError(err)
def _stop_server(self):
"""Stop the MySQL server"""
if not self._process:
return False
try:
process_terminate(self._process.pid)
except (OSError, ValueError) as err:
raise MySQLServerError(err)
return True
def get_exec(self, exec_name):
"""Find executable in the MySQL directories
Returns the the full path to the executable named exec_name or
None when the executable was not found.
Return str or None.
"""
for location in [self._bindir, self._sbindir]:
exec_path = os.path.join(location, exec_name)
if os.access(exec_path, 0):
return exec_path
return None
class MySQLServer(MySQLServerBase):
"""Class for managing a MySQL server"""
def __init__(self, basedir, topdir, cnf, bind_address, port,
name, datadir=None, tmpdir=None,
unix_socket_folder=None, ssl_folder=None, sharedir=None):
self._cnf = cnf
self._option_file = os.path.join(topdir, 'my.cnf')
self._bind_address = bind_address
self._port = port
self._topdir = topdir
self._basedir = basedir
self._ssldir = ssl_folder or topdir
self._datadir = datadir or os.path.join(topdir, 'data')
self._tmpdir = tmpdir or os.path.join(topdir, 'tmp')
self._name = name
self._unix_socket = os.path.join(unix_socket_folder or self._topdir,
'mysql_cpy_' + name + '.sock')
self._pid_file = os.path.join(topdir,
'mysql_cpy_' + name + '.pid')
self._serverid = port + 100000
self._install = None
self._server = None
self._debug = False
self._sharedir = sharedir
self.client_config = {}
super(MySQLServer, self).__init__(self._basedir,
self._option_file,
sharedir=self._sharedir)
def _create_directories(self):
"""Create directory structure for bootstrapping
Create the directories needed for bootstrapping a MySQL
installation, i.e. 'mysql' directory.
The 'test' database is deliberately not created.
Raises MySQLBootstrapError when something fails.
"""
dirs = [
self._topdir,
os.path.join(self._topdir, 'tmp'),
self._datadir,
os.path.join(self._datadir, 'mysql')
]
for adir in dirs:
LOGGER.debug("Creating directory %s", adir)
os.mkdir(adir)
def _get_bootstrap_cmd(self):
"""Get the command for bootstrapping.
Get the command which will be used for bootstrapping. This is
the full path to the mysqld executable and its arguments.
Returns a list (used with subprocess.Popen)
"""
cmd = [
os.path.join(self._sbindir, EXEC_MYSQLD),
'--no-defaults',
'--bootstrap',
'--basedir=%s' % self._basedir,
'--datadir=%s' % self._datadir,
'--log-warnings=0',
'--max_allowed_packet=8M',
'--default-storage-engine=myisam',
'--net_buffer_length=16K',
'--tmpdir=%s' % self._tmpdir,
'--innodb_log_file_size=1Gb',
]
if self._version[0:2] < (5, 5):
cmd.append('--language={0}/english'.format(self._lc_messages_dir))
else:
cmd.extend([
'--lc-messages-dir={0}'.format(self._lc_messages_dir),
'--lc-messages=en_US'
])
if self._version[0:2] >= (5, 1):
cmd.append('--loose-skip-ndbcluster')
return cmd
def bootstrap(self):
"""Bootstrap a MySQL installation
Bootstrap a MySQL installation using the mysqld executable
and the --bootstrap option. Arguments are defined by reading
the defaults file and options set in the _get_bootstrap_cmd()
method.
Raises MySQLBootstrapError when something fails.
"""
if os.access(self._datadir, 0):
raise MySQLBootstrapError("Datadir exists, can't bootstrap MySQL")
# Order is important
script_files = (
'mysql_system_tables.sql',
'mysql_system_tables_data.sql',
'fill_help_tables.sql',
)
# Extra SQL statements to execute after SQL scripts
extra_sql = [
"CREATE DATABASE myconnpy;"
]
insert = (
"INSERT INTO mysql.user VALUES ('localhost','root'{0},"
"'Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y',"
"'Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y','Y',"
"'Y','Y','Y','Y','Y','','','','',0,0,0,0,"
"@@default_authentication_plugin,'','N',"
"CURRENT_TIMESTAMP,NULL{1});"
)
# MySQL 5.7.5+ creates no user while bootstrapping
if self._version[0:3] >= (5, 7, 6):
# MySQL 5.7.6+ have extra account_locked col and no password col
extra_sql.append(insert.format("", ",'N'"))
elif self._version[0:3] >= (5, 7, 5):
extra_sql.append(insert.format(",''", ""))
insert_localhost = (
"INSERT INTO mysql.user SELECT '127.0.0.1', `User`{0},"
" `Select_priv`, `Insert_priv`, `Update_priv`, `Delete_priv`,"
" `Create_priv`, `Drop_priv`, `Reload_priv`, `Shutdown_priv`,"
" `Process_priv`, `File_priv`, `Grant_priv`, `References_priv`,"
" `Index_priv`, `Alter_priv`, `Show_db_priv`, `Super_priv`,"
" `Create_tmp_table_priv`, `Lock_tables_priv`, `Execute_priv`,"
" `Repl_slave_priv`, `Repl_client_priv`, `Create_view_priv`,"
" `Show_view_priv`, `Create_routine_priv`, "
"`Alter_routine_priv`,"
" `Create_user_priv`, `Event_priv`, `Trigger_priv`, "
"`Create_tablespace_priv`, `ssl_type`, `ssl_cipher`,"
"`x509_issuer`, `x509_subject`, `max_questions`, `max_updates`,"
"`max_connections`, `max_user_connections`, `plugin`,"
"`authentication_string`, `password_expired`,"
"`password_last_changed`, `password_lifetime`{1} FROM mysql.user "
"WHERE `user` = 'root' and `host` = 'localhost';"
)
# MySQL 5.7.4+ only creates root@localhost
if self._version[0:3] >= (5, 7, 6):
extra_sql.append(insert_localhost.format("", ",`account_locked`"))
elif self._version[0:3] >= (5, 7, 4):
extra_sql.append(insert_localhost.format(",`Password`", ""))
bootstrap_log = os.path.join(self._topdir, 'bootstrap.log')
try:
self._create_directories()
cmd = self._get_bootstrap_cmd()
sql = ["USE mysql;"]
for filename in script_files:
full_path = os.path.join(self._scriptdir, filename)
LOGGER.debug("Reading SQL from '%s'", full_path)
with open(full_path, 'r') as fp:
sql.extend([line.strip() for line in fp.readlines()])
sql.extend(extra_sql)
fp_log = open(bootstrap_log, 'w')
prc = subprocess.Popen(cmd, stdin=subprocess.PIPE,
stderr=subprocess.STDOUT, stdout=fp_log)
if sys.version_info[0] == 2:
prc.communicate('\n'.join(sql))
else:
prc.communicate(bytearray('\n'.join(sql), 'utf8'))
fp_log.close()
except OSError as err:
raise MySQLBootstrapError(
"Error bootstrapping MySQL '{name}': {error}".format(
name=self._name, error=str(err)))
with open(bootstrap_log, 'r') as fp:
log_lines = fp.readlines()
for log_line in log_lines:
if '[ERROR]' in log_line:
err_msg = log_line.split('[ERROR]')[1].strip()
raise MySQLBootstrapError(
"Error bootstrapping MySQL '{name}': {error}".format(
name=self._name, error=err_msg))
@property
def name(self):
"""Returns the name of this MySQL server"""
return self._name
@property
def port(self):
"""Return TCP/IP port of the server"""
return self._port
@property
def bind_address(self):
"""Return IP address the server is listening on"""
return self._bind_address
@property
def unix_socket(self):
"""Return the unix socket of the server"""
return self._unix_socket
def start(self):
"""Start a MySQL server"""
if self.check_running():
LOGGER.error("MySQL server '{name}' already running".format(
name=self.name))
return
options = {
'name': self._name,
'basedir': _convert_forward_slash(self._basedir),
'datadir': _convert_forward_slash(self._datadir),
'tmpdir': _convert_forward_slash(self._tmpdir),
'bind_address': self._bind_address,
'port': self._port,
'unix_socket': _convert_forward_slash(self._unix_socket),
'ssl_dir': _convert_forward_slash(self._ssldir),
'pid_file': _convert_forward_slash(self._pid_file),
'serverid': self._serverid,
'lc_messages_dir': _convert_forward_slash(
self._lc_messages_dir),
}
try:
fp = open(self._option_file, 'w')
fp.write(self._cnf.format(**options))
fp.close()
self._start_server()
for i in range(10):
if self.check_running():
break
time.sleep(5)
except MySQLServerError as err:
if self._debug is True:
raise
LOGGER.error("Failed starting MySQL server "
"'{name}': {error}".format(name=self.name,
error=str(err)))
sys.exit(1)
else:
pid = get_pid(self._pid_file)
if not pid:
LOGGER.error("Failed getting PID of MySQL server "
"'{name}' (file {pid_file}".format(
name=self._name, pid_file=self._pid_file))
sys.exit(1)
LOGGER.debug("MySQL server started '{name}' "
"(pid={pid})".format(pid=pid, name=self._name))
def stop(self):
"""Stop the MySQL server
Stop the MySQL server and returns whether it was successful or not.
This method stops the process and exits when it failed to stop the
server due to an error. When the process was killed, but it the
process is still found to be running, False is returned. When
the server was stopped successfully, True is returned.
Raises MySQLServerError or OSError when debug is enabled.
Returns True or False.
"""
pid = get_pid(self._pid_file)
if not pid:
return
try:
if not self._stop_server():
process_terminate(pid)
except (MySQLServerError, OSError) as err:
if self._debug is True:
raise
LOGGER.error("Failed stopping MySQL server '{name}': "
"{error}".format(error=str(err), name=self._name))
sys.exit(1)
else:
time.sleep(3)
if self.check_running(pid):
LOGGER.debug("MySQL server stopped '{name}' "
"(pid={pid})".format(pid=pid, name=self._name))
return True
return False
def remove(self):
"""Remove the topdir of the MySQL server"""
if not os.path.exists(self._topdir) or self.check_running():
return
try:
rmtree(self._topdir)
except OSError as err:
LOGGER.debug("Failed removing %s: %s", self._topdir, err)
if self._debug is True:
raise
else:
LOGGER.info("Removed {folder}".format(folder=self._topdir))
def check_running(self, pid=None):
"""Check if MySQL server is running
Check if the MySQL server is running using the given pid, or when
not specified, using the PID found in the PID file.
Returns True or False.
"""
pid = pid or get_pid(self._pid_file)
if pid:
LOGGER.debug("Got PID %d", pid)
return process_running(pid)
return False
def wait_up(self, tries=10, delay=1):
"""Wait until the MySQL server is up
This method can be used to wait until the MySQL server is started.
True is returned when the MySQL server is up, False otherwise.
Return True or False.
"""
running = self.check_running()
while not running:
if tries == 0:
break
time.sleep(delay)
running = self.check_running()
tries -= 1
return running
def wait_down(self, tries=10, delay=1):
"""Wait until the MySQL server is down
This method can be used to wait until the MySQL server has stopped.
True is returned when the MySQL server is down, False otherwise.
Return True or False.
"""
running = self.check_running()
while running:
if tries == 0:
break
time.sleep(delay)
running = self.check_running()
tries -= 1
return not running
class DummyMySQLRequestHandler(BaseRequestHandler):
def __init__(self, request, client_address, server):
super(DummyMySQLRequestHandler, self).__init__(request, client_address,
server)
def read_packet(self):
"""Read a MySQL packet from the socket.
:return: Tuple with type and payload of packet.
:rtype: tuple
"""
header = bytearray(self.request.recv(4))
if not header:
return
length = struct.unpack('<I', header[0:3] + '\x00')[0]
self._curr_pktnr = struct.unpack('B', header[-1])[0]
data = self.request.recv(length)
return header + data
def handle(self):
if self.server.sock_error:
raise socket.error(self.server.socket_error)
res = self._server_replies[0:bufsize]
self._server_replies = self._server_replies[bufsize:]
return res
class DummyMySQLServer(ThreadingMixIn, TCPServer):
"""Class accepting connections for testing MySQL connections"""
def __init__(self, *args, **kwargs):
TCPServer.__init__(self, *args, **kwargs)
self._server_replies = bytearray(b'')
self._client_sends = []
def finish_request(self, request, client_address):
"""Finish one request by instantiating RequestHandlerClass."""
self.RequestHandlerClass(request, client_address, self)
def raise_socket_error(self, err=errno.EPERM):
self.socket_error = err
def add_packet(self, packet):
self._server_replies += packet
def add_packets(self, packets):
for packet in packets:
self._server_replies += packet
def reset(self):
self._raise_socket_error = 0
self._server_replies = bytearray(b'')
self._client_sends = []
def get_address(self):
return 'dummy'

View File

@@ -1,67 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
from functools import wraps
import types
class SkipTest(Exception):
"""Exception compatible with SkipTest of Python v2.7 and later"""
def _id(obj):
"""Function defined in unittest.case which is needed for decorators"""
return obj
def test_skip(reason):
"""Skip test
This decorator is used by Python v2.6 code to keep compatible with
Python v2.7 (and later) unittest.skip.
"""
def decorator(test):
if not isinstance(test, (type, types.ClassType)):
@wraps(test)
def wrapper(*args, **kwargs):
raise SkipTest(reason)
test = wrapper
test.__unittest_skip__ = True
test.__unittest_skip_why__ = reason
return test
return decorator
def test_skip_if(condition, reason):
"""Skip test if condition is true
This decorator is used by Python v2.6 code to keep compatible with
Python v2.7 (and later) unittest.skipIf.
"""
if condition:
return test_skip(reason)
return _id

View File

@@ -1,284 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.abstracts
"""
from decimal import Decimal
from operator import attrgetter
import unittest
import tests
from tests import PY2, foreach_cnx
from mysql.connector.connection import MySQLConnection
from mysql.connector.constants import RefreshOption
from mysql.connector import errors
try:
from mysql.connector.connection_cext import CMySQLConnection
except ImportError:
# Test without C Extension
CMySQLConnection = None
class ConnectionSubclasses(tests.MySQLConnectorTests):
"""Tests for any subclass of MySQLConnectionAbstract
"""
def asEq(self, exp, *cases):
for case in cases:
self.assertEqual(exp, case)
@foreach_cnx()
def test_properties_getter(self):
properties = [
(self.config['user'], 'user'),
(self.config['host'], 'server_host'),
(self.config['port'], 'server_port'),
(self.config['unix_socket'], 'unix_socket'),
(self.config['database'], 'database')
]
for exp, property in properties:
f = attrgetter(property)
self.asEq(exp, f(self.cnx))
@foreach_cnx()
def test_time_zone(self):
orig = self.cnx.info_query("SELECT @@session.time_zone")[0]
self.assertEqual(orig, self.cnx.time_zone)
self.cnx.time_zone = "+02:00"
self.assertEqual("+02:00", self.cnx.time_zone)
@foreach_cnx()
def test_sql_mode(self):
orig = self.cnx.info_query("SELECT @@session.sql_mode")[0]
self.assertEqual(orig, self.cnx.sql_mode)
try:
self.cnx.sql_mode = 'SPAM'
except errors.ProgrammingError:
pass # excepted
else:
self.fail("ProgrammingError not raises")
# Set SQL Mode to a list of modes
if tests.MYSQL_VERSION[0:3] < (5, 7, 4):
exp = ('STRICT_TRANS_TABLES,STRICT_ALL_TABLES,NO_ZERO_IN_DATE,'
'NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,TRADITIONAL,'
'NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION')
else:
exp = ('STRICT_TRANS_TABLES,STRICT_ALL_TABLES,TRADITIONAL,'
'NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION')
try:
self.cnx.sql_mode = exp
except errors.Error as err:
self.fail("Failed setting SQL Mode with multiple "
"modes: {0}".format(str(err)))
self.assertEqual(exp, self.cnx._sql_mode)
# SQL Modes must be empty
self.cnx.sql_mode = ''
self.assertEqual('', self.cnx.sql_mode)
# Set SQL Mode and check
sql_mode = exp = 'STRICT_ALL_TABLES'
self.cnx.sql_mode = sql_mode
self.assertEqual(exp, self.cnx.sql_mode)
# Unset the SQL Mode again
self.cnx.sql_mode = ''
self.assertEqual('', self.cnx.sql_mode)
@foreach_cnx()
def test_in_transaction(self):
self.cnx.cmd_query('START TRANSACTION')
self.assertTrue(self.cnx.in_transaction)
self.cnx.cmd_query('ROLLBACK')
self.assertFalse(self.cnx.in_transaction)
# AUTO_COMMIT turned ON
self.cnx.autocommit = True
self.assertFalse(self.cnx.in_transaction)
self.cnx.cmd_query('START TRANSACTION')
self.assertTrue(self.cnx.in_transaction)
@foreach_cnx()
def test_disconnect(self):
self.cnx.disconnect()
self.assertFalse(self.cnx.is_connected())
@foreach_cnx()
def test_is_connected(self):
"""Check connection to MySQL Server"""
self.assertEqual(True, self.cnx.is_connected())
self.cnx.disconnect()
self.assertEqual(False, self.cnx.is_connected())
@foreach_cnx()
def test_info_query(self):
queries = [
("SELECT 1",
(1,)),
("SELECT 'ham', 'spam'",
((u'ham', u'spam')))
]
for query, exp in queries:
self.assertEqual(exp, self.cnx.info_query(query))
@foreach_cnx()
def test_cmd_init_db(self):
self.assertRaises(errors.ProgrammingError,
self.cnx.cmd_init_db, 'unknown_database')
self.cnx.cmd_init_db(u'INFORMATION_SCHEMA')
self.assertEqual('INFORMATION_SCHEMA', self.cnx.database.upper())
self.cnx.cmd_init_db('mysql')
self.assertEqual(u'mysql', self.cnx.database)
self.cnx.cmd_init_db('myconnpy')
self.assertEqual(u'myconnpy', self.cnx.database)
@foreach_cnx()
def test_reset_session(self):
exp = [True, u'STRICT_ALL_TABLES', u'-09:00', 33]
self.cnx.autocommit = exp[0]
self.cnx.sql_mode = exp[1]
self.cnx.time_zone = exp[2]
self.cnx.set_charset_collation(exp[3])
user_variables = {'ham': '1', 'spam': '2'}
session_variables = {'wait_timeout': 100000}
self.cnx.reset_session(user_variables, session_variables)
self.assertEqual(exp, [self.cnx.autocommit, self.cnx.sql_mode,
self.cnx.time_zone, self.cnx._charset_id])
exp_user_variables = {'ham': '1', 'spam': '2'}
exp_session_variables = {'wait_timeout': 100000}
for key, value in exp_user_variables.items():
row = self.cnx.info_query("SELECT @{0}".format(key))
self.assertEqual(value, row[0])
for key, value in exp_session_variables.items():
row = self.cnx.info_query("SELECT @@session.{0}".format(key))
self.assertEqual(value, row[0])
@unittest.skipIf(tests.MYSQL_VERSION > (5, 7, 10),
"As of MySQL 5.7.11, mysql_refresh() is deprecated")
@foreach_cnx()
def test_cmd_refresh(self):
refresh = RefreshOption.LOG | RefreshOption.THREADS
exp = {'insert_id': 0, 'affected_rows': 0,
'field_count': 0, 'warning_count': 0,
'status_flag': 0}
self.assertEqual(exp, self.cnx.cmd_refresh(refresh))
query = "SHOW GLOBAL STATUS LIKE 'Uptime_since_flush_status'"
pre_flush = int(self.cnx.info_query(query)[1])
self.cnx.cmd_refresh(RefreshOption.STATUS)
post_flush = int(self.cnx.info_query(query)[1])
self.assertTrue(post_flush <= pre_flush)
@foreach_cnx()
def test_cmd_quit(self):
self.cnx.cmd_quit()
self.assertFalse(self.cnx.is_connected())
@foreach_cnx()
def test_cmd_shutdown(self):
server = tests.MYSQL_SERVERS[0]
# We make sure the connection is re-established.
self.cnx = self.cnx.__class__(**self.config)
self.cnx.cmd_shutdown()
if not server.wait_down():
self.fail("[{0}] ".format(self.cnx.__class__.__name__) +
"MySQL not shut down after cmd_shutdown()")
self.assertRaises(errors.Error, self.cnx.cmd_shutdown)
server.start()
if not server.wait_up():
self.fail("Failed restarting MySQL server after test")
@foreach_cnx()
def test_cmd_statistics(self):
exp = {
'Uptime': int,
'Open tables': int,
'Queries per second avg': Decimal,
'Slow queries': int,
'Threads': int,
'Questions': int,
'Flush tables': int,
'Opens': int
}
stat = self.cnx.cmd_statistics()
self.assertEqual(len(exp), len(stat))
for key, type_ in exp.items():
self.assertTrue(key in stat)
self.assertTrue(isinstance(stat[key], type_))
@foreach_cnx()
def test_cmd_process_info(self):
self.assertRaises(errors.NotSupportedError,
self.cnx.cmd_process_info)
@foreach_cnx()
def test_cmd_process_kill(self):
other_cnx = self.cnx.__class__(**self.config)
pid = other_cnx.connection_id
self.cnx.cmd_process_kill(pid)
self.assertFalse(other_cnx.is_connected())
@foreach_cnx()
def test_start_transaction(self):
self.cnx.start_transaction()
self.assertTrue(self.cnx.in_transaction)
self.cnx.rollback()
self.cnx.start_transaction(consistent_snapshot=True)
self.assertTrue(self.cnx.in_transaction)
self.assertRaises(errors.ProgrammingError,
self.cnx.start_transaction)
self.cnx.rollback()
levels = ['READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ',
'SERIALIZABLE',
'READ-UNCOMMITTED', 'READ-COMMITTED', 'REPEATABLE-READ',
'SERIALIZABLE']
for level in levels:
level = level.replace(' ', '-')
self.cnx.start_transaction(isolation_level=level)
self.assertTrue(self.cnx.in_transaction)
self.cnx.rollback()
self.assertRaises(ValueError,
self.cnx.start_transaction,
isolation_level='spam')

View File

@@ -1,195 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Incur., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Test module for authentication
"""
import inspect
import sys
import mysql.connector
from mysql.connector import authentication
import tests
from . import PY2
_STANDARD_PLUGINS = (
'mysql_native_password',
'mysql_clear_password',
'sha256_password',
)
class AuthenticationModuleTests(tests.MySQLConnectorTests):
"""Tests globals and functions of the authentication module"""
def test_get_auth_plugin(self):
self.assertRaises(mysql.connector.NotSupportedError,
authentication.get_auth_plugin, 'spam')
self.assertRaises(mysql.connector.NotSupportedError,
authentication.get_auth_plugin, '')
# Test using standard plugins
plugin_classes = {}
for name, obj in inspect.getmembers(authentication):
if inspect.isclass(obj) and hasattr(obj, 'plugin_name'):
if obj.plugin_name:
plugin_classes[obj.plugin_name] = obj
for plugin_name in _STANDARD_PLUGINS:
self.assertEqual(plugin_classes[plugin_name],
authentication.get_auth_plugin(plugin_name),
"Failed getting class for {0}".format(plugin_name))
class BaseAuthPluginTests(tests.MySQLConnectorTests):
"""Tests authentication.BaseAuthPlugin"""
def test_class(self):
self.assertEqual('', authentication.BaseAuthPlugin.plugin_name)
self.assertEqual(False, authentication.BaseAuthPlugin.requires_ssl)
def test___init__(self):
base = authentication.BaseAuthPlugin('ham')
self.assertEqual('ham', base._auth_data)
self.assertEqual(None, base._username)
self.assertEqual(None, base._password)
self.assertEqual(None, base._database)
self.assertEqual(False, base._ssl_enabled)
base = authentication.BaseAuthPlugin(
'spam', username='ham', password='secret',
database='test', ssl_enabled=True)
self.assertEqual('spam', base._auth_data)
self.assertEqual('ham', base._username)
self.assertEqual('secret', base._password)
self.assertEqual('test', base._database)
self.assertEqual(True, base._ssl_enabled)
def test_prepare_password(self):
base = authentication.BaseAuthPlugin('ham')
self.assertRaises(NotImplementedError, base.prepare_password)
def test_auth_response(self):
base = authentication.BaseAuthPlugin('ham')
self.assertRaises(NotImplementedError, base.auth_response)
base.requires_ssl = True
self.assertRaises(mysql.connector.InterfaceError, base.auth_response)
class MySQLNativePasswordAuthPluginTests(tests.MySQLConnectorTests):
"""Tests authentication.MySQLNativePasswordAuthPlugin"""
def setUp(self):
self.plugin_class = authentication.MySQLNativePasswordAuthPlugin
def test_class(self):
self.assertEqual('mysql_native_password', self.plugin_class.plugin_name)
self.assertEqual(False, self.plugin_class.requires_ssl)
def test_prepare_password(self):
auth_plugin = self.plugin_class(None, password='spam')
self.assertRaises(mysql.connector.InterfaceError,
auth_plugin.prepare_password)
auth_plugin = self.plugin_class(123456, password='spam') # too long
self.assertRaises(mysql.connector.InterfaceError,
auth_plugin.prepare_password)
if PY2:
empty = ''
auth_data = (
'\x3b\x55\x78\x7d\x2c\x5f\x7c\x72\x49\x52'
'\x3f\x28\x47\x6f\x77\x28\x5f\x28\x46\x69'
)
auth_response = (
'\x3a\x07\x66\xba\xba\x01\xce\xbe\x55\xe6'
'\x29\x88\xaa\xae\xdb\x00\xb3\x4d\x91\x5b'
)
else:
empty = b''
auth_data = (
b'\x3b\x55\x78\x7d\x2c\x5f\x7c\x72\x49\x52'
b'\x3f\x28\x47\x6f\x77\x28\x5f\x28\x46\x69'
)
auth_response = (
b'\x3a\x07\x66\xba\xba\x01\xce\xbe\x55\xe6'
b'\x29\x88\xaa\xae\xdb\x00\xb3\x4d\x91\x5b'
)
auth_plugin = self.plugin_class('\x3f'*20, password=None)
self.assertEqual(empty, auth_plugin.prepare_password())
auth_plugin = self.plugin_class(auth_data, password='spam')
self.assertEqual(auth_response, auth_plugin.prepare_password())
self.assertEqual(auth_response, auth_plugin.auth_response())
class MySQLClearPasswordAuthPluginTests(tests.MySQLConnectorTests):
"""Tests authentication.MySQLClearPasswordAuthPlugin"""
def setUp(self):
self.plugin_class = authentication.MySQLClearPasswordAuthPlugin
def test_class(self):
self.assertEqual('mysql_clear_password', self.plugin_class.plugin_name)
self.assertEqual(True, self.plugin_class.requires_ssl)
def test_prepare_password(self):
if PY2:
exp = 'spam\x00'
else:
exp = b'spam\x00'
auth_plugin = self.plugin_class(None, password='spam', ssl_enabled=True)
self.assertEqual(exp, auth_plugin.prepare_password())
self.assertEqual(exp, auth_plugin.auth_response())
class MySQLSHA256PasswordAuthPluginTests(tests.MySQLConnectorTests):
"""Tests authentication.MySQLSHA256PasswordAuthPlugin"""
def setUp(self):
self.plugin_class = authentication.MySQLSHA256PasswordAuthPlugin
def test_class(self):
self.assertEqual('sha256_password', self.plugin_class.plugin_name)
self.assertEqual(True, self.plugin_class.requires_ssl)
def test_prepare_password(self):
if PY2:
exp = 'spam\x00'
else:
exp = b'spam\x00'
auth_plugin = self.plugin_class(None, password='spam', ssl_enabled=True)
self.assertEqual(exp, auth_plugin.prepare_password())
self.assertEqual(exp, auth_plugin.auth_response())

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,485 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.constants
"""
import tests
from mysql.connector import constants, errors
class Helpers(tests.MySQLConnectorTests):
def test_flag_is_set(self):
"""Check if a particular flag/bit is set"""
data = [
1 << 3,
1 << 5,
1 << 7,
]
flags = 0
for flag in data:
flags |= flag
for flag in data:
self.assertTrue(constants.flag_is_set(flag, flags))
self.assertFalse(constants.flag_is_set(1 << 4, flags))
def test_MAX_PACKET_LENGTH(self):
"""Check MAX_PACKET_LENGTH"""
self.assertEqual(16777215, constants.MAX_PACKET_LENGTH)
def test_NET_BUFFER_LENGTH(self):
"""Check NET_BUFFER_LENGTH"""
self.assertEqual(8192, constants.NET_BUFFER_LENGTH)
class FieldTypeTests(tests.MySQLConnectorTests):
desc = {
'DECIMAL': (0x00, 'DECIMAL'),
'TINY': (0x01, 'TINY'),
'SHORT': (0x02, 'SHORT'),
'LONG': (0x03, 'LONG'),
'FLOAT': (0x04, 'FLOAT'),
'DOUBLE': (0x05, 'DOUBLE'),
'NULL': (0x06, 'NULL'),
'TIMESTAMP': (0x07, 'TIMESTAMP'),
'LONGLONG': (0x08, 'LONGLONG'),
'INT24': (0x09, 'INT24'),
'DATE': (0x0a, 'DATE'),
'TIME': (0x0b, 'TIME'),
'DATETIME': (0x0c, 'DATETIME'),
'YEAR': (0x0d, 'YEAR'),
'NEWDATE': (0x0e, 'NEWDATE'),
'VARCHAR': (0x0f, 'VARCHAR'),
'BIT': (0x10, 'BIT'),
'NEWDECIMAL': (0xf6, 'NEWDECIMAL'),
'ENUM': (0xf7, 'ENUM'),
'SET': (0xf8, 'SET'),
'TINY_BLOB': (0xf9, 'TINY_BLOB'),
'MEDIUM_BLOB': (0xfa, 'MEDIUM_BLOB'),
'LONG_BLOB': (0xfb, 'LONG_BLOB'),
'BLOB': (0xfc, 'BLOB'),
'VAR_STRING': (0xfd, 'VAR_STRING'),
'STRING': (0xfe, 'STRING'),
'GEOMETRY': (0xff, 'GEOMETRY'),
}
type_groups = {
'string': [
constants.FieldType.VARCHAR,
constants.FieldType.ENUM,
constants.FieldType.VAR_STRING, constants.FieldType.STRING,
],
'binary': [
constants.FieldType.TINY_BLOB, constants.FieldType.MEDIUM_BLOB,
constants.FieldType.LONG_BLOB, constants.FieldType.BLOB,
],
'number': [
constants.FieldType.DECIMAL, constants.FieldType.NEWDECIMAL,
constants.FieldType.TINY, constants.FieldType.SHORT,
constants.FieldType.LONG,
constants.FieldType.FLOAT, constants.FieldType.DOUBLE,
constants.FieldType.LONGLONG, constants.FieldType.INT24,
constants.FieldType.BIT,
constants.FieldType.YEAR,
],
'datetime': [
constants.FieldType.DATETIME, constants.FieldType.TIMESTAMP,
],
}
def test_attributes(self):
"""Check attributes for FieldType"""
self.assertEqual('FIELD_TYPE_', constants.FieldType.prefix)
for key, value in self.desc.items():
self.assertTrue(key in constants.FieldType.__dict__,
'{0} is not an attribute of FieldType'.format(key))
self.assertEqual(
value[0], constants.FieldType.__dict__[key],
'{0} attribute of FieldType has wrong value'.format(key))
def test_get_desc(self):
"""Get field type by name"""
for key, value in self.desc.items():
exp = value[1]
res = constants.FieldType.get_desc(key)
self.assertEqual(exp, res)
self.assertEqual(None, constants.FieldType.get_desc('FooBar'))
def test_get_info(self):
"""Get field type by id"""
for _, value in self.desc.items():
exp = value[1]
res = constants.FieldType.get_info(value[0])
self.assertEqual(exp, res)
self.assertEqual(None, constants.FieldType.get_info(999999999))
def test_get_string_types(self):
"""DBAPI string types"""
self.assertEqual(self.type_groups['string'],
constants.FieldType.get_string_types())
def test_get_binary_types(self):
"""DBAPI string types"""
self.assertEqual(self.type_groups['binary'],
constants.FieldType.get_binary_types())
def test_get_number_types(self):
"""DBAPI number types"""
self.assertEqual(self.type_groups['number'],
constants.FieldType.get_number_types())
def test_get_timestamp_types(self):
"""DBAPI datetime types"""
self.assertEqual(self.type_groups['datetime'],
constants.FieldType.get_timestamp_types())
class FieldFlagTests(tests.MySQLConnectorTests):
desc = {
'NOT_NULL': (1 << 0, "Field can't be NULL"),
'PRI_KEY': (1 << 1, "Field is part of a primary key"),
'UNIQUE_KEY': (1 << 2, "Field is part of a unique key"),
'MULTIPLE_KEY': (1 << 3, "Field is part of a key"),
'BLOB': (1 << 4, "Field is a blob"),
'UNSIGNED': (1 << 5, "Field is unsigned"),
'ZEROFILL': (1 << 6, "Field is zerofill"),
'BINARY': (1 << 7, "Field is binary "),
'ENUM': (1 << 8, "field is an enum"),
'AUTO_INCREMENT': (1 << 9, "field is a autoincrement field"),
'TIMESTAMP': (1 << 10, "Field is a timestamp"),
'SET': (1 << 11, "field is a set"),
'NO_DEFAULT_VALUE': (1 << 12, "Field doesn't have default value"),
'ON_UPDATE_NOW': (1 << 13, "Field is set to NOW on UPDATE"),
'NUM': (1 << 14, "Field is num (for clients)"),
'PART_KEY': (1 << 15, "Intern; Part of some key"),
'GROUP': (1 << 14, "Intern: Group field"), # Same as NUM
'UNIQUE': (1 << 16, "Intern: Used by sql_yacc"),
'BINCMP': (1 << 17, "Intern: Used by sql_yacc"),
'GET_FIXED_FIELDS': (1 << 18, "Used to get fields in item tree"),
'FIELD_IN_PART_FUNC': (1 << 19, "Field part of partition func"),
'FIELD_IN_ADD_INDEX': (1 << 20, "Intern: Field used in ADD INDEX"),
'FIELD_IS_RENAMED': (1 << 21, "Intern: Field is being renamed"),
}
def test_attributes(self):
"""Check attributes for FieldFlag"""
self.assertEqual('', constants.FieldFlag._prefix)
for key, value in self.desc.items():
self.assertTrue(key in constants.FieldFlag.__dict__,
'{0} is not an attribute of FieldFlag'.format(key))
self.assertEqual(
value[0], constants.FieldFlag.__dict__[key],
'{0} attribute of FieldFlag has wrong value'.format(key))
def test_get_desc(self):
"""Get field flag by name"""
for key, value in self.desc.items():
exp = value[1]
res = constants.FieldFlag.get_desc(key)
self.assertEqual(exp, res)
def test_get_info(self):
"""Get field flag by id"""
for exp, info in self.desc.items():
# Ignore the NUM/GROUP (bug in MySQL source code)
if info[0] == 1 << 14:
break
res = constants.FieldFlag.get_info(info[0])
self.assertEqual(exp, res)
def test_get_bit_info(self):
"""Get names of the set flags"""
data = 0
data |= constants.FieldFlag.BLOB
data |= constants.FieldFlag.BINARY
exp = ['BINARY', 'BLOB'].sort()
self.assertEqual(exp, constants.FieldFlag.get_bit_info(data).sort())
class ClientFlagTests(tests.MySQLConnectorTests):
desc = {
'LONG_PASSWD': (1 << 0, 'New more secure passwords'),
'FOUND_ROWS': (1 << 1, 'Found instead of affected rows'),
'LONG_FLAG': (1 << 2, 'Get all column flags'),
'CONNECT_WITH_DB': (1 << 3, 'One can specify db on connect'),
'NO_SCHEMA': (1 << 4, "Don't allow database.table.column"),
'COMPRESS': (1 << 5, 'Can use compression protocol'),
'ODBC': (1 << 6, 'ODBC client'),
'LOCAL_FILES': (1 << 7, 'Can use LOAD DATA LOCAL'),
'IGNORE_SPACE': (1 << 8, "Ignore spaces before ''"),
'PROTOCOL_41': (1 << 9, 'New 4.1 protocol'),
'INTERACTIVE': (1 << 10, 'This is an interactive client'),
'SSL': (1 << 11, 'Switch to SSL after handshake'),
'IGNORE_SIGPIPE': (1 << 12, 'IGNORE sigpipes'),
'TRANSACTIONS': (1 << 13, 'Client knows about transactions'),
'RESERVED': (1 << 14, 'Old flag for 4.1 protocol'),
'SECURE_CONNECTION': (1 << 15, 'New 4.1 authentication'),
'MULTI_STATEMENTS': (1 << 16, 'Enable/disable multi-stmt support'),
'MULTI_RESULTS': (1 << 17, 'Enable/disable multi-results'),
'SSL_VERIFY_SERVER_CERT': (1 << 30, ''),
'REMEMBER_OPTIONS': (1 << 31, ''),
}
def test_attributes(self):
"""Check attributes for ClientFlag"""
for key, value in self.desc.items():
self.assertTrue(key in constants.ClientFlag.__dict__,
'{0} is not an attribute of FieldFlag'.format(key))
self.assertEqual(
value[0], constants.ClientFlag.__dict__[key],
'{0} attribute of FieldFlag has wrong value'.format(key))
def test_get_desc(self):
"""Get client flag by name"""
for key, value in self.desc.items():
exp = value[1]
res = constants.ClientFlag.get_desc(key)
self.assertEqual(exp, res)
def test_get_info(self):
"""Get client flag by id"""
for exp, info in self.desc.items():
res = constants.ClientFlag.get_info(info[0])
self.assertEqual(exp, res)
def test_get_bit_info(self):
"""Get names of the set flags"""
data = 0
data |= constants.ClientFlag.LONG_FLAG
data |= constants.ClientFlag.LOCAL_FILES
exp = ['LONG_FLAG', 'LOCAL_FILES'].sort()
self.assertEqual(exp, constants.ClientFlag.get_bit_info(data).sort())
def test_get_default(self):
"""Get client flags which are set by default.
"""
data = [
constants.ClientFlag.LONG_PASSWD,
constants.ClientFlag.LONG_FLAG,
constants.ClientFlag.CONNECT_WITH_DB,
constants.ClientFlag.PROTOCOL_41,
constants.ClientFlag.TRANSACTIONS,
constants.ClientFlag.SECURE_CONNECTION,
constants.ClientFlag.MULTI_STATEMENTS,
constants.ClientFlag.MULTI_RESULTS,
constants.ClientFlag.LOCAL_FILES,
]
exp = 0
for option in data:
exp |= option
self.assertEqual(exp, constants.ClientFlag.get_default())
class CharacterSetTests(tests.MySQLConnectorTests):
"""Tests for constants.CharacterSet"""
def test_get_info(self):
"""Get info about charset using MySQL ID"""
exp = ('utf8', 'utf8_general_ci')
data = 33
self.assertEqual(exp, constants.CharacterSet.get_info(data))
exception = errors.ProgrammingError
data = 50000
self.assertRaises(exception, constants.CharacterSet.get_info, data)
def test_get_desc(self):
"""Get info about charset using MySQL ID as string"""
exp = 'utf8/utf8_general_ci'
data = 33
self.assertEqual(exp, constants.CharacterSet.get_desc(data))
exception = errors.ProgrammingError
data = 50000
self.assertRaises(exception, constants.CharacterSet.get_desc, data)
def test_get_default_collation(self):
"""Get default collation for a given Character Set"""
func = constants.CharacterSet.get_default_collation
data = 'sjis'
exp = ('sjis_japanese_ci', data, 13)
self.assertEqual(exp, func(data))
self.assertEqual(exp, func(exp[2]))
exception = errors.ProgrammingError
data = 'foobar'
self.assertRaises(exception, func, data)
def test_get_charset_info(self):
"""Get info about charset by name and collation"""
func = constants.CharacterSet.get_charset_info
exp = (209, 'utf8', 'utf8_esperanto_ci')
data = exp[1:]
self.assertEqual(exp, func(data[0], data[1]))
self.assertEqual(exp, func(None, data[1]))
self.assertEqual(exp, func(exp[0]))
self.assertRaises(errors.ProgrammingError,
constants.CharacterSet.get_charset_info, 666)
self.assertRaises(
errors.ProgrammingError,
constants.CharacterSet.get_charset_info, charset='utf8',
collation='utf8_spam_ci')
self.assertRaises(
errors.ProgrammingError,
constants.CharacterSet.get_charset_info,
collation='utf8_spam_ci')
def test_get_supported(self):
"""Get list of all supported character sets"""
exp = (
'big5', 'latin2', 'dec8', 'cp850', 'latin1', 'hp8', 'koi8r',
'swe7', 'ascii', 'ujis', 'sjis', 'cp1251', 'hebrew', 'tis620',
'euckr', 'latin7', 'koi8u', 'gb2312', 'greek', 'cp1250', 'gbk',
'cp1257', 'latin5', 'armscii8', 'utf8', 'ucs2', 'cp866', 'keybcs2',
'macce', 'macroman', 'cp852', 'utf8mb4','utf16', 'utf16le',
'cp1256', 'utf32', 'binary', 'geostd8', 'cp932', 'eucjpms',
'gb18030',
)
self.assertEqual(exp, constants.CharacterSet.get_supported())
class SQLModesTests(tests.MySQLConnectorTests):
modes = (
'REAL_AS_FLOAT',
'PIPES_AS_CONCAT',
'ANSI_QUOTES',
'IGNORE_SPACE',
'NOT_USED',
'ONLY_FULL_GROUP_BY',
'NO_UNSIGNED_SUBTRACTION',
'NO_DIR_IN_CREATE',
'POSTGRESQL',
'ORACLE',
'MSSQL',
'DB2',
'MAXDB',
'NO_KEY_OPTIONS',
'NO_TABLE_OPTIONS',
'NO_FIELD_OPTIONS',
'MYSQL323',
'MYSQL40',
'ANSI',
'NO_AUTO_VALUE_ON_ZERO',
'NO_BACKSLASH_ESCAPES',
'STRICT_TRANS_TABLES',
'STRICT_ALL_TABLES',
'NO_ZERO_IN_DATE',
'NO_ZERO_DATE',
'INVALID_DATES',
'ERROR_FOR_DIVISION_BY_ZERO',
'TRADITIONAL',
'NO_AUTO_CREATE_USER',
'HIGH_NOT_PRECEDENCE',
'NO_ENGINE_SUBSTITUTION',
'PAD_CHAR_TO_FULL_LENGTH',
)
def test_get_info(self):
for mode in SQLModesTests.modes:
self.assertEqual(mode, getattr(constants.SQLMode, mode),
'Wrong info for SQL Mode {0}'.format(mode))
def test_get_full_info(self):
modes = tuple(sorted(SQLModesTests.modes))
self.assertEqual(modes,
constants.SQLMode.get_full_info())
class ShutdownTypeTests(tests.MySQLConnectorTests):
"""Test COM_SHUTDOWN types"""
desc = {
'SHUTDOWN_DEFAULT': (
0,
"defaults to SHUTDOWN_WAIT_ALL_BUFFERS"
),
'SHUTDOWN_WAIT_CONNECTIONS': (
1,
"wait for existing connections to finish"
),
'SHUTDOWN_WAIT_TRANSACTIONS': (
2,
"wait for existing trans to finish"
),
'SHUTDOWN_WAIT_UPDATES': (
8,
"wait for existing updates to finish"
),
'SHUTDOWN_WAIT_ALL_BUFFERS': (
16,
"flush InnoDB and other storage engine buffers"
),
'SHUTDOWN_WAIT_CRITICAL_BUFFERS': (
17,
"don't flush InnoDB buffers, flush other storage engines' buffers"
),
'KILL_QUERY': (
254, "(no description)"
),
'KILL_CONNECTION': (
255, "(no description)"
),
}
def test_attributes(self):
"""Check attributes for FieldType"""
self.assertEqual('', constants.ShutdownType.prefix)
for key, value in self.desc.items():
self.assertTrue(key in constants.ShutdownType.__dict__,
'{0} is not an attribute of FieldType'.format(key))
self.assertEqual(
value[0], constants.ShutdownType.__dict__[key],
'{0} attribute of ShutdownType has wrong value'.format(key))
def test_get_desc(self):
"""Get field flag by name"""
for key, value in self.desc.items():
exp = value[1]
res = constants.ShutdownType.get_desc(key)
self.assertEqual(exp, res)

View File

@@ -1,514 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.conversion
"""
from decimal import Decimal
import datetime
import time
import uuid
import tests
from mysql.connector import conversion, constants
from mysql.connector.catch23 import PY2
class MySQLConverterBaseTests(tests.MySQLConnectorTests):
def test_init(self):
cnv = conversion.MySQLConverterBase()
self.assertEqual('utf8', cnv.charset)
self.assertEqual(True, cnv.use_unicode)
def test_init2(self):
cnv = conversion.MySQLConverterBase(charset='latin1',
use_unicode=False)
self.assertEqual('latin1', cnv.charset)
self.assertEqual(False, cnv.use_unicode)
def test_set_charset(self):
cnv = conversion.MySQLConverterBase()
cnv.set_charset('latin2')
self.assertEqual('latin2', cnv.charset)
def test_set_useunicode(self):
cnv = conversion.MySQLConverterBase()
cnv.set_unicode(False)
self.assertEqual(False, cnv.use_unicode)
def test_to_mysql(self):
cnv = conversion.MySQLConverterBase()
self.assertEqual('a value', cnv.to_mysql('a value'))
def test_to_python(self):
cnv = conversion.MySQLConverterBase()
self.assertEqual('a value', cnv.to_python('nevermind', 'a value'))
def test_escape(self):
cnv = conversion.MySQLConverterBase()
self.assertEqual("'a value'", cnv.escape("'a value'"))
def test_quote(self):
cnv = conversion.MySQLConverterBase()
self.assertEqual("'a value'", cnv.escape("'a value'"))
class MySQLConverterTests(tests.MySQLConnectorTests):
_to_python_data = [
(b'3.14', ('float', constants.FieldType.FLOAT)),
(b'128', ('int', constants.FieldType.TINY)),
(b'1281288', ('long', constants.FieldType.LONG)),
(b'3.14', ('decimal', constants.FieldType.DECIMAL)),
(b'2008-05-07', ('date', constants.FieldType.DATE)),
(b'45:34:10', ('time', constants.FieldType.TIME)),
(b'2008-05-07 22:34:10',
('datetime', constants.FieldType.DATETIME)),
(b'val1,val2', ('set', constants.FieldType.SET, None,
None, None, None, True, constants.FieldFlag.SET)),
('2008', ('year', constants.FieldType.YEAR)),
(b'\x80\x00\x00\x00', ('bit', constants.FieldType.BIT)),
(b'\xc3\xa4 utf8 string', ('utf8', constants.FieldType.STRING,
None, None, None, None, True, 0)),
]
_to_python_exp = (
float(_to_python_data[0][0]),
int(_to_python_data[1][0]),
int(_to_python_data[2][0]),
Decimal('3.14'),
datetime.date(2008, 5, 7),
datetime.timedelta(hours=45, minutes=34, seconds=10),
datetime.datetime(2008, 5, 7, 22, 34, 10),
set(['val1', 'val2']),
int(_to_python_data[8][0]),
2147483648,
unicode(b'\xc3\xa4 utf8 string', 'utf8') if PY2 \
else str(b'\xc3\xa4 utf8 string', 'utf8')
)
def setUp(self):
self.cnv = conversion.MySQLConverter()
def tearDown(self):
pass
def test_init(self):
pass
def test_escape(self):
"""Making strings ready for MySQL operations"""
data = (
None, # should stay the same
int(128), # should stay the same
int(1281288), # should stay the same
float(3.14), # should stay the same
Decimal('3.14'), # should stay a Decimal
r'back\slash',
'newline\n',
'return\r',
"'single'",
'"double"',
'windows\032',
)
exp = (
None,
128,
1281288,
float(3.14),
Decimal("3.14"),
'back\\\\slash',
'newline\\n',
'return\\r',
"\\'single\\'",
'\\"double\\"',
'windows\\\x1a'
)
res = tuple([self.cnv.escape(v) for v in data])
self.assertTrue(res, exp)
def test_quote(self):
"""Quote values making them ready for MySQL operations."""
data = [
None,
int(128),
int(1281288),
float(3.14),
Decimal('3.14'),
b'string A',
b"string B",
]
exp = (
b'NULL',
b'128',
b'1281288',
repr(float(3.14)) if PY2 else b'3.14',
b'3.14',
b"'string A'",
b"'string B'",
)
res = tuple([self.cnv.quote(value) for value in data])
self.assertEqual(res, exp)
def test_to_mysql(self):
"""Convert Python types to MySQL types using helper method"""
st_now = time.localtime()
data = (
128, # int
1281288, # long
float(3.14), # float
'Strings are sexy',
r'\u82b1',
None,
datetime.datetime(2008, 5, 7, 20, 0o1, 23),
datetime.date(2008, 5, 7),
datetime.time(20, 0o3, 23),
st_now,
datetime.timedelta(hours=40, minutes=30, seconds=12),
Decimal('3.14'),
)
exp = (
data[0],
data[1],
data[2],
self.cnv._str_to_mysql(data[3]),
self.cnv._str_to_mysql(data[4]),
None,
b'2008-05-07 20:01:23',
b'2008-05-07',
b'20:03:23',
time.strftime('%Y-%m-%d %H:%M:%S', st_now).encode('ascii'),
b'40:30:12',
b'3.14',
)
res = tuple([self.cnv.to_mysql(value) for value in data])
self.assertEqual(res, exp)
self.assertRaises(TypeError, self.cnv.to_mysql, uuid.uuid4())
def test__str_to_mysql(self):
"""A Python string becomes bytes."""
data = 'This is a string'
exp = data.encode()
res = self.cnv._str_to_mysql(data)
self.assertEqual(exp, res)
def test__bytes_to_mysql(self):
"""A Python bytes stays bytes."""
data = b'This is a bytes'
exp = data
res = self.cnv._bytes_to_mysql(data)
self.assertEqual(exp, res)
def test__bytearray_to_mysql(self):
"""A Python bytearray becomes bytes."""
data = bytearray(b'This is a bytearray',)
exp = bytes(data)
res = self.cnv._bytearray_to_mysql(data)
self.assertEqual(exp, res)
def test__nonetype_to_mysql(self):
"""Python None stays None for MySQL."""
data = None
res = self.cnv._nonetype_to_mysql(data)
self.assertEqual(data, res)
def test__datetime_to_mysql(self):
"""A datetime.datetime becomes formatted like Y-m-d H:M:S[.f]"""
cases = [
(datetime.datetime(2008, 5, 7, 20, 1, 23),
b'2008-05-07 20:01:23'),
(datetime.datetime(2012, 5, 2, 20, 1, 23, 10101),
b'2012-05-02 20:01:23.010101')
]
for data, exp in cases:
self.assertEqual(exp, self.cnv._datetime_to_mysql(data))
def test__date_to_mysql(self):
"""A datetime.date becomes formatted like Y-m-d"""
data = datetime.date(2008, 5, 7)
res = self.cnv._date_to_mysql(data)
exp = data.strftime('%Y-%m-%d').encode('ascii')
self.assertEqual(exp, res)
def test__time_to_mysql(self):
"""A datetime.time becomes formatted like Y-m-d H:M:S[.f]"""
cases = [
(datetime.time(20, 3, 23), b'20:03:23'),
(datetime.time(20, 3, 23, 10101), b'20:03:23.010101'),
]
for data, exp in cases:
self.assertEqual(exp, self.cnv._time_to_mysql(data))
def test__struct_time_to_mysql(self):
"""A time.struct_time becomes formatted like Y-m-d H:M:S"""
data = time.localtime()
res = self.cnv._struct_time_to_mysql(data)
exp = time.strftime('%Y-%m-%d %H:%M:%S', data).encode('ascii')
self.assertEqual(exp, res)
def test__timedelta_to_mysql(self):
"""A datetime.timedelta becomes format like 'H:M:S[.f]'"""
cases = [
(datetime.timedelta(hours=40, minutes=30, seconds=12),
b'40:30:12'),
(datetime.timedelta(hours=-40, minutes=30, seconds=12),
b'-39:29:48'),
(datetime.timedelta(hours=40, minutes=-1, seconds=12),
b'39:59:12'),
(datetime.timedelta(hours=-40, minutes=60, seconds=12),
b'-38:59:48'),
(datetime.timedelta(hours=40, minutes=30, seconds=12,
microseconds=10101),
b'40:30:12.010101'),
(datetime.timedelta(hours=-40, minutes=30, seconds=12,
microseconds=10101),
b'-39:29:47.989899'),
(datetime.timedelta(hours=40, minutes=-1, seconds=12,
microseconds=10101),
b'39:59:12.010101'),
(datetime.timedelta(hours=-40, minutes=60, seconds=12,
microseconds=10101),
b'-38:59:47.989899'),
]
for i, case in enumerate(cases):
data, exp = case
self.assertEqual(exp, self.cnv._timedelta_to_mysql(data),
"Case {0} failed: {1}; got {2}".format(
i + 1, repr(data),
self.cnv._timedelta_to_mysql(data)))
def test__decimal_to_mysql(self):
"""A decimal.Decimal becomes a string."""
data = Decimal('3.14')
self.assertEqual(b'3.14', self.cnv._decimal_to_mysql(data))
def test_to_python(self):
"""Convert MySQL data to Python types using helper method"""
res = tuple(
[self.cnv.to_python(v[1], v[0]) for v in self._to_python_data])
self.assertEqual(res, tuple(self._to_python_exp))
def test_row_to_python(self):
data = [v[0] for v in self._to_python_data]
description = [v[1] for v in self._to_python_data]
res = self.cnv.row_to_python(data, description)
self.assertEqual(res, self._to_python_exp)
def test__FLOAT_to_python(self):
"""Convert a MySQL FLOAT/DOUBLE to a Python float type"""
data = b'3.14'
exp = float(data)
res = self.cnv._FLOAT_to_python(data)
self.assertEqual(exp, res)
self.assertEqual(self.cnv._FLOAT_to_python,
self.cnv._DOUBLE_to_python)
def test__INT_to_python(self):
"""Convert a MySQL TINY/SHORT/INT24/INT to a Python int type"""
data = b'128'
exp = int(data)
res = self.cnv._INT_to_python(data)
self.assertEqual(exp, res)
self.assertEqual(self.cnv._INT_to_python, self.cnv._TINY_to_python)
self.assertEqual(self.cnv._INT_to_python, self.cnv._SHORT_to_python)
self.assertEqual(self.cnv._INT_to_python, self.cnv._INT24_to_python)
def test__LONG_to_python(self):
"""Convert a MySQL LONG/LONGLONG to a Python long type"""
data = b'1281288'
exp = int(data)
res = self.cnv._LONG_to_python(data)
self.assertEqual(exp, res)
self.assertEqual(self.cnv._LONG_to_python,
self.cnv._LONGLONG_to_python)
def test__DECIMAL_to_python(self):
"""Convert a MySQL DECIMAL to a Python decimal.Decimal type"""
data = b'3.14'
exp = Decimal('3.14')
res = self.cnv._DECIMAL_to_python(data)
self.assertEqual(exp, res)
self.assertEqual(self.cnv._DECIMAL_to_python,
self.cnv._NEWDECIMAL_to_python)
def test__BIT_to_python(self):
"""Convert a MySQL BIT to Python int"""
data = [
b'\x80',
b'\x80\x00',
b'\x80\x00\x00',
b'\x80\x00\x00\x00',
b'\x80\x00\x00\x00\x00',
b'\x80\x00\x00\x00\x00\x00',
b'\x80\x00\x00\x00\x00\x00\x00',
b'\x80\x00\x00\x00\x00\x00\x00\x00',
]
exp = [128, 32768, 8388608, 2147483648, 549755813888,
140737488355328, 36028797018963968, 9223372036854775808]
for i, buf in enumerate(data):
self.assertEqual(self.cnv._BIT_to_python(buf), exp[i])
def test__DATE_to_python(self):
"""Convert a MySQL DATE to a Python datetime.date type"""
data = b'2008-05-07'
exp = datetime.date(2008, 5, 7)
res = self.cnv._DATE_to_python(data)
self.assertEqual(exp, res)
res = self.cnv._DATE_to_python(b'0000-00-00')
self.assertEqual(None, res)
res = self.cnv._DATE_to_python(b'1000-00-00')
self.assertEqual(None, res)
def test__TIME_to_python(self):
"""Convert a MySQL TIME to a Python datetime.time type"""
cases = [
(b'45:34:10',
datetime.timedelta(hours=45, minutes=34, seconds=10)),
(b'-45:34:10',
datetime.timedelta(-2, 8750)),
(b'45:34:10.010101',
datetime.timedelta(hours=45, minutes=34, seconds=10,
microseconds=10101)),
(b'-45:34:10.010101',
datetime.timedelta(-2, 8749, 989899)),
]
for i, case in enumerate(cases):
data, exp = case
self.assertEqual(exp, self.cnv._TIME_to_python(data),
"Case {0} failed: {1}; got {2}".format(
i + 1, repr(data),
repr(self.cnv._TIME_to_python(data))))
def test__DATETIME_to_python(self):
"""Convert a MySQL DATETIME to a Python datetime.datetime type"""
cases = [
(b'2008-05-07 22:34:10',
datetime.datetime(2008, 5, 7, 22, 34, 10)),
(b'2008-05-07 22:34:10.010101',
datetime.datetime(2008, 5, 7, 22, 34, 10, 10101)),
(b'0000-00-00 00:00:00', None),
(b'1000-00-00 00:00:00', None),
]
for data, exp in cases:
self.assertEqual(exp, self.cnv._DATETIME_to_python(data))
def test__YEAR_to_python(self):
"""Convert a MySQL YEAR to Python int"""
data = '2008'
exp = 2008
self.assertEqual(exp, self.cnv._YEAR_to_python(data))
data = 'foobar'
self.assertRaises(ValueError, self.cnv._YEAR_to_python, data)
def test__SET_to_python(self):
"""Convert a MySQL SET type to a Python sequence
This actually calls hte _STRING_to_python() method since a SET is
returned as string by MySQL. However, the description of the field
has in it's field flags that the string is a SET.
"""
data = b'val1,val2'
exp = set(['val1', 'val2'])
desc = ('foo', constants.FieldType.STRING,
2, 3, 4, 5, 6, constants.FieldFlag.SET)
res = self.cnv._STRING_to_python(data, desc)
self.assertEqual(exp, res)
def test__STRING_to_python_utf8(self):
"""Convert a UTF-8 MySQL STRING/VAR_STRING to a Python Unicode type"""
self.cnv.set_charset('utf8') # default
data = b'\xc3\xa4 utf8 string'
exp = data.decode('utf-8')
res = self.cnv._STRING_to_python(data)
self.assertEqual(exp, res)
def test__STRING_to_python_latin1(self):
"""Convert a ISO-8859-1 MySQL STRING/VAR_STRING to a Python str"""
self.cnv.set_charset('latin1')
self.cnv.set_unicode(False)
data = b'\xe4 latin string'
exp = data
res = self.cnv._STRING_to_python(data)
self.assertEqual(exp, res)
exp = data.decode('latin1')
self.cnv.set_unicode(True)
res = self.cnv._STRING_to_python(data)
self.assertEqual(exp, res)
self.cnv.set_charset('utf8')
self.cnv.set_unicode(True)
def test__STRING_to_python_binary(self):
"""Convert a STRING BINARY to Python bytes type"""
data = b'\x33\xfd\x34\xed'
desc = ('foo', constants.FieldType.STRING,
2, 3, 4, 5, 6, constants.FieldFlag.BINARY)
res = self.cnv._STRING_to_python(data, desc)
self.assertEqual(data, res)
def test__BLOB_to_python_binary(self):
"""Convert a BLOB BINARY to Python bytes type"""
data = b'\x33\xfd\x34\xed'
desc = ('foo', constants.FieldType.BLOB,
2, 3, 4, 5, 6, constants.FieldFlag.BINARY)
res = self.cnv._BLOB_to_python(data, desc)
self.assertEqual(data, res)

File diff suppressed because it is too large Load Diff

View File

@@ -1,343 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.django
"""
import datetime
import unittest
import sys
import unittest
import tests
# Load 3rd party _after_ loading tests
try:
from django.conf import settings
except ImportError:
DJANGO_AVAILABLE = False
else:
DJANGO_AVAILABLE = True
# Have to setup Django before loading anything else
if DJANGO_AVAILABLE:
try:
settings.configure()
except RuntimeError as exc:
if not 'already configured' in str(exc):
raise
DBCONFIG = tests.get_mysql_config()
settings.DATABASES = {
'default': {
'ENGINE': 'mysql.connector.django',
'NAME': DBCONFIG['database'],
'USER': 'root',
'PASSWORD': '',
'HOST': DBCONFIG['host'],
'PORT': DBCONFIG['port'],
'TEST_CHARSET': 'utf8',
'TEST_COLLATION': 'utf8_general_ci',
'CONN_MAX_AGE': 0,
'AUTOCOMMIT': True,
},
}
settings.SECRET_KEY = "django_tests_secret_key"
settings.TIME_ZONE = 'UTC'
settings.USE_TZ = False
settings.SOUTH_TESTS_MIGRATE = False
settings.DEBUG = False
TABLES = {}
TABLES['django_t1'] = """
CREATE TABLE {table_name} (
id INT NOT NULL AUTO_INCREMENT,
c1 INT,
c2 VARCHAR(20),
INDEX (c1),
UNIQUE INDEX (c2),
PRIMARY KEY (id)
) ENGINE=InnoDB
"""
TABLES['django_t2'] = """
CREATE TABLE {table_name} (
id INT NOT NULL AUTO_INCREMENT,
id_t1 INT NOT NULL,
INDEX (id_t1),
PRIMARY KEY (id),
FOREIGN KEY (id_t1) REFERENCES django_t1(id) ON DELETE CASCADE
) ENGINE=InnoDB
"""
# Have to load django.db to make importing db backend work for Django < 1.6
import django.db # pylint: disable=W0611
if tests.DJANGO_VERSION >= (1, 6):
if tests.DJANGO_VERSION >= (1, 8):
from django.db.backends.base.introspection import FieldInfo
else:
from django.db.backends import FieldInfo
from django.db.backends.signals import connection_created
from django.utils.safestring import SafeBytes, SafeText
import mysql.connector
if DJANGO_AVAILABLE:
from mysql.connector.django.base import (
DatabaseWrapper, DatabaseOperations, DjangoMySQLConverter)
from mysql.connector.django.introspection import DatabaseIntrospection
@unittest.skipIf(not DJANGO_AVAILABLE, "Django not available")
class DjangoIntrospection(tests.MySQLConnectorTests):
"""Test the Django introspection module"""
cnx = None
introspect = None
def setUp(self):
# Python 2.6 has no setUpClass, we run it here, once.
if sys.version_info < (2, 7) and not self.__class__.cnx:
self.__class__.setUpClass()
@classmethod
def setUpClass(cls):
dbconfig = tests.get_mysql_config()
cls.cnx = DatabaseWrapper(settings.DATABASES['default'])
cls.introspect = DatabaseIntrospection(cls.cnx)
cur = cls.cnx.cursor()
for table_name, sql in TABLES.items():
cur.execute("SET foreign_key_checks = 0")
cur.execute("DROP TABLE IF EXISTS {table_name}".format(
table_name=table_name))
cur.execute(sql.format(table_name=table_name))
cur.execute("SET foreign_key_checks = 1")
@classmethod
def tearDownClass(cls):
cur = cls.cnx.cursor()
cur.execute("SET foreign_key_checks = 0")
for table_name, sql in TABLES.items():
cur.execute("DROP TABLE IF EXISTS {table_name}".format(
table_name=table_name))
cur.execute("SET foreign_key_checks = 1")
def test_get_table_list(self):
cur = self.cnx.cursor()
exp = list(TABLES.keys())
for exp in list(TABLES.keys()):
if sys.version_info < (2, 7):
self.assertTrue(exp in self.introspect.get_table_list(cur))
else:
self.assertIn(exp, self.introspect.get_table_list(cur),
"Table {table_name} not in table list".format(
table_name=exp))
def test_get_table_description(self):
cur = self.cnx.cursor()
if tests.DJANGO_VERSION < (1, 6):
exp = [
('id', 3, None, None, None, None, 0, 16899),
('c1', 3, None, None, None, None, 1, 16392),
('c2', 253, None, 20, None, None, 1, 16388)
]
else:
exp = [
FieldInfo(name='id', type_code=3, display_size=None,
internal_size=None, precision=None, scale=None,
null_ok=0),
FieldInfo(name='c1', type_code=3, display_size=None,
internal_size=None, precision=None, scale=None,
null_ok=1),
FieldInfo(name='c2', type_code=253, display_size=None,
internal_size=20, precision=None, scale=None,
null_ok=1)
]
res = self.introspect.get_table_description(cur, 'django_t1')
self.assertEqual(exp, res)
def test_get_relations(self):
cur = self.cnx.cursor()
exp = {1: (0, 'django_t1')}
self.assertEqual(exp, self.introspect.get_relations(cur, 'django_t2'))
def test_get_key_columns(self):
cur = self.cnx.cursor()
exp = [('id_t1', 'django_t1', 'id')]
self.assertEqual(exp, self.introspect.get_key_columns(cur, 'django_t2'))
def test_get_indexes(self):
cur = self.cnx.cursor()
exp = {
'c1': {'primary_key': False, 'unique': False},
'id': {'primary_key': True, 'unique': True},
'c2': {'primary_key': False, 'unique': True}
}
self.assertEqual(exp, self.introspect.get_indexes(cur, 'django_t1'))
def test_get_primary_key_column(self):
cur = self.cnx.cursor()
res = self.introspect.get_primary_key_column(cur, 'django_t1')
self.assertEqual('id', res)
@unittest.skipIf(not DJANGO_AVAILABLE, "Django not available")
class DjangoDatabaseWrapper(tests.MySQLConnectorTests):
"""Test the Django base.DatabaseWrapper class"""
def setUp(self):
dbconfig = tests.get_mysql_config()
self.conn = mysql.connector.connect(**dbconfig)
self.cnx = DatabaseWrapper(settings.DATABASES['default'])
def test__init__(self):
exp = self.conn.get_server_version()
self.assertEqual(exp, self.cnx.mysql_version)
value = datetime.time(2, 5, 7)
exp = self.conn.converter._time_to_mysql(value)
self.assertEqual(exp, self.cnx.ops.value_to_db_time(value))
self.cnx.connection = None
value = datetime.time(2, 5, 7)
exp = self.conn.converter._time_to_mysql(value)
self.assertEqual(exp, self.cnx.ops.value_to_db_time(value))
def test_signal(self):
from django.db import connection
def conn_setup(*args, **kwargs):
conn = kwargs['connection']
settings.DEBUG = True
cur = conn.cursor()
settings.DEBUG = False
cur.execute("SET @xyz=10")
cur.close()
connection_created.connect(conn_setup)
cursor = connection.cursor()
cursor.execute("SELECT @xyz")
self.assertEqual((10,), cursor.fetchone())
cursor.close()
self.cnx.close()
def count_conn(self, *args, **kwargs):
try:
self.connections += 1
except AttributeError:
self.connection = 1
def test_connections(self):
connection_created.connect(self.count_conn)
self.connections = 0
# Checking if DatabaseWrapper object creates a connection by default
conn = DatabaseWrapper(settings.DATABASES['default'])
dbo = DatabaseOperations(conn)
dbo.value_to_db_time(datetime.time(3, 3, 3))
self.assertEqual(self.connections, 0)
class DjangoDatabaseOperations(tests.MySQLConnectorTests):
"""Test the Django base.DatabaseOperations class"""
def setUp(self):
dbconfig = tests.get_mysql_config()
self.conn = mysql.connector.connect(**dbconfig)
self.cnx = DatabaseWrapper(settings.DATABASES['default'])
self.dbo = DatabaseOperations(self.cnx)
def test_value_to_db_time(self):
self.assertEqual(None, self.dbo.value_to_db_time(None))
value = datetime.time(0, 0, 0)
exp = self.conn.converter._time_to_mysql(value)
self.assertEqual(exp, self.dbo.value_to_db_time(value))
value = datetime.time(2, 5, 7)
exp = self.conn.converter._time_to_mysql(value)
self.assertEqual(exp, self.dbo.value_to_db_time(value))
def test_value_to_db_datetime(self):
self.assertEqual(None, self.dbo.value_to_db_datetime(None))
value = datetime.datetime(1, 1, 1)
exp = self.conn.converter._datetime_to_mysql(value)
self.assertEqual(exp, self.dbo.value_to_db_datetime(value))
value = datetime.datetime(2, 5, 7, 10, 10)
exp = self.conn.converter._datetime_to_mysql(value)
self.assertEqual(exp, self.dbo.value_to_db_datetime(value))
class DjangoMySQLConverterTests(tests.MySQLConnectorTests):
"""Test the Django base.DjangoMySQLConverter class"""
def test__TIME_to_python(self):
value = b'10:11:12'
django_converter = DjangoMySQLConverter()
self.assertEqual(datetime.time(10, 11, 12),
django_converter._TIME_to_python(value, dsc=None))
def test__DATETIME_to_python(self):
value = b'1990-11-12 00:00:00'
django_converter = DjangoMySQLConverter()
self.assertEqual(datetime.datetime(1990, 11, 12, 0, 0, 0),
django_converter._DATETIME_to_python(value, dsc=None))
settings.USE_TZ = True
value = b'0000-00-00 00:00:00'
django_converter = DjangoMySQLConverter()
self.assertEqual(None,
django_converter._DATETIME_to_python(value, dsc=None))
settings.USE_TZ = False
class BugOra20106629(tests.MySQLConnectorTests):
"""CONNECTOR/PYTHON DJANGO BACKEND DOESN'T SUPPORT SAFETEXT"""
def setUp(self):
dbconfig = tests.get_mysql_config()
self.conn = mysql.connector.connect(**dbconfig)
self.cnx = DatabaseWrapper(settings.DATABASES['default'])
self.cur = self.cnx.cursor()
self.tbl = "BugOra20106629"
self.cur.execute("DROP TABLE IF EXISTS {0}".format(self.tbl), ())
self.cur.execute("CREATE TABLE {0}(col1 TEXT, col2 BLOB)".format(self.tbl), ())
def teardown(self):
self.cur.execute("DROP TABLE IF EXISTS {0}".format(self.tbl), ())
def test_safe_string(self):
safe_text = SafeText("dummy & safe data <html> ")
safe_bytes = SafeBytes(b"\x00\x00\x4c\x6e\x67\x39")
self.cur.execute("INSERT INTO {0} VALUES(%s, %s)".format(self.tbl), (safe_text, safe_bytes))
self.cur.execute("SELECT * FROM {0}".format(self.tbl), ())
self.assertEqual(self.cur.fetchall(), [(safe_text, safe_bytes)])

View File

@@ -1,64 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2012, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.errorcode
"""
from datetime import datetime
import tests
from mysql.connector import errorcode
class ErrorCodeTests(tests.MySQLConnectorTests):
def test__MYSQL_VERSION(self):
minimum = (5, 6, 6)
self.assertTrue(isinstance(errorcode._MYSQL_VERSION, tuple))
self.assertTrue(len(errorcode._MYSQL_VERSION) == 3)
self.assertTrue(errorcode._MYSQL_VERSION >= minimum)
def _check_code(self, code, num):
try:
self.assertEqual(getattr(errorcode, code), num)
except AttributeError as err:
self.fail(err)
def test_server_error_codes(self):
cases = {
'ER_HASHCHK': 1000,
'ER_TRG_INVALID_CREATION_CTX': 1604,
'ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION': 1792,
}
for code, num in cases.items():
self._check_code(code, num)
def test_client_error_codes(self):
cases = {
'CR_UNKNOWN_ERROR': 2000,
'CR_PROBE_SLAVE_STATUS': 2022,
'CR_AUTH_PLUGIN_CANNOT_LOAD': 2059,
}
for code, num in cases.items():
self._check_code(code, num)

View File

@@ -1,235 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2012, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.errors
"""
import tests
from mysql.connector import errors
class ErrorsTests(tests.MySQLConnectorTests):
def test_custom_error_exception(self):
customfunc = errors.custom_error_exception
self.assertRaises(ValueError, customfunc, 'spam')
self.assertRaises(ValueError, customfunc, 1)
self.assertRaises(ValueError, customfunc, 1, 'spam')
case = (1, errors.InterfaceError)
exp = {1: errors.InterfaceError}
self.assertEqual(exp, customfunc(*case))
exp = case = {1: errors.InterfaceError, 2: errors.ProgrammingError}
self.assertEqual(exp, customfunc(case))
case = {1: errors.InterfaceError, 2: None}
self.assertRaises(ValueError, customfunc, case)
case = {1: errors.InterfaceError, 2: str()}
self.assertRaises(ValueError, customfunc, case)
case = {'1': errors.InterfaceError}
self.assertRaises(ValueError, customfunc, case)
self.assertEqual({}, customfunc({}))
self.assertEqual({}, errors._CUSTOM_ERROR_EXCEPTIONS)
def test_get_mysql_exception(self):
tests = {
errors.ProgrammingError: (
'24', '25', '26', '27', '28', '2A', '2C',
'34', '35', '37', '3C', '3D', '3F', '42'),
errors.DataError: ('02', '21', '22'),
errors.NotSupportedError: ('0A',),
errors.IntegrityError: ('23', 'XA'),
errors.InternalError: ('40', '44'),
errors.OperationalError: ('08', 'HZ', '0K'),
errors.DatabaseError: ('07', '2B', '2D', '2E', '33', 'ZZ', 'HY'),
}
msg = 'Ham'
for exp, errlist in tests.items():
for sqlstate in errlist:
errno = 1000
res = errors.get_mysql_exception(errno, msg, sqlstate)
self.assertTrue(isinstance(res, exp),
"SQLState {0} should be {1}".format(
sqlstate, exp.__name__))
self.assertEqual(sqlstate, res.sqlstate)
self.assertEqual("{0} ({1}): {2}".format(errno, sqlstate, msg),
str(res))
errno = 1064
sqlstate = "42000"
msg = "You have an error in your SQL syntax"
exp = "1064 (42000): You have an error in your SQL syntax"
err = errors.get_mysql_exception(errno, msg, sqlstate)
self.assertEqual(exp, str(err))
# Hardcoded exceptions
self.assertTrue(isinstance(errors._ERROR_EXCEPTIONS, dict))
self.assertTrue(
isinstance(errors.get_mysql_exception(1243, None, None),
errors.ProgrammingError))
# Custom exceptions
errors._CUSTOM_ERROR_EXCEPTIONS[1064] = errors.DatabaseError
self.assertTrue(
isinstance(errors.get_mysql_exception(1064, None, None),
errors.DatabaseError))
errors._CUSTOM_ERROR_EXCEPTIONS = {}
def test_get_exception(self):
ok_packet = bytearray(b'\x07\x00\x00\x01\x00\x01\x00\x00\x00\x01\x00')
err_packet = bytearray(
b'\x47\x00\x00\x02\xff\x15\x04\x23\x32\x38\x30\x30\x30'
b'\x41\x63\x63\x65\x73\x73\x20\x64\x65\x6e\x69\x65\x64'
b'\x20\x66\x6f\x72\x20\x75\x73\x65\x72\x20\x27\x68\x61'
b'\x6d\x27\x40\x27\x6c\x6f\x63\x61\x6c\x68\x6f\x73\x74'
b'\x27\x20\x28\x75\x73\x69\x6e\x67\x20\x70\x61\x73\x73'
b'\x77\x6f\x72\x64\x3a\x20\x59\x45\x53\x29'
)
self.assertTrue(isinstance(errors.get_exception(err_packet),
errors.ProgrammingError))
self.assertRaises(ValueError,
errors.get_exception, ok_packet)
res = errors.get_exception(bytearray(b'\x47\x00\x00\x02\xff\x15'))
self.assertTrue(isinstance(res, errors.InterfaceError))
class ErrorTest(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.Error, Exception))
err = errors.Error()
self.assertEqual(-1, err.errno)
self.assertEqual('Unknown error', err.msg)
self.assertEqual(None, err.sqlstate)
msg = 'Ham'
err = errors.Error(msg, errno=1)
self.assertEqual(1, err.errno)
self.assertEqual('1: {0}'.format(msg), err._full_msg)
self.assertEqual(msg, err.msg)
err = errors.Error('Ham', errno=1, sqlstate="SPAM")
self.assertEqual(1, err.errno)
self.assertEqual('1 (SPAM): Ham', err._full_msg)
self.assertEqual('1 (SPAM): Ham', str(err))
err = errors.Error(errno=2000)
self.assertEqual('Unknown MySQL error', err.msg)
self.assertEqual('2000: Unknown MySQL error', err._full_msg)
err = errors.Error(errno=2003, values=('/path/to/ham', 2))
self.assertEqual(
"2003: Can't connect to MySQL server on '/path/to/ham' (2)",
err._full_msg)
self.assertEqual(
"Can't connect to MySQL server on '/path/to/ham' (2)",
err.msg)
err = errors.Error(errno=2001, values=('ham',))
if '(Warning:' in str(err):
self.fail('Found %d in error message.')
err = errors.Error(errno=2003, values=('ham',))
self.assertEqual(
"2003: Can't connect to MySQL server on '%-.100s' (%s) "
"(Warning: not enough arguments for format string)",
err._full_msg)
def test___str__(self):
msg = "Spam"
self.assertEqual("Spam", str(errors.Error(msg)))
self.assertEqual("1: Spam", str(errors.Error(msg, 1)))
self.assertEqual("1 (XYZ): Spam",
str(errors.Error(msg, 1, sqlstate='XYZ')))
class WarningTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.Warning, Exception))
class InterfaceErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.InterfaceError, errors.Error))
class DatabaseErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.DatabaseError, errors.Error))
class InternalErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.InternalError,
errors.DatabaseError))
class OperationalErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.OperationalError,
errors.DatabaseError))
class ProgrammingErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.ProgrammingError,
errors.DatabaseError))
class IntegrityErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.IntegrityError,
errors.DatabaseError))
class DataErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.DataError,
errors.DatabaseError))
class NotSupportedErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.NotSupportedError,
errors.DatabaseError))
class PoolErrorTests(tests.MySQLConnectorTests):
def test___init__(self):
self.assertTrue(issubclass(errors.PoolError, errors.Error))

View File

@@ -1,225 +0,0 @@
# -*- coding: utf-8 -*-
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for examples
"""
from hashlib import md5
import sys
import tests
from . import PY2
import mysql.connector
class TestExamples(tests.MySQLConnectorTests):
def setUp(self):
self.config = tests.get_mysql_config()
self.config['use_pure'] = True
self.cnx = mysql.connector.connect(**self.config)
def tearDown(self):
self.cnx.close()
def _exec_main(self, example, exp=None):
try:
result = example.main(self.config)
if not exp:
return result
except Exception as err:
self.fail(err)
md5_result = md5()
output = u'\n'.join(result)
md5_result.update(output.encode('utf-8'))
self.assertEqual(exp, md5_result.hexdigest(),
'Output was not correct')
def test_dates(self):
"""examples/dates.py"""
try:
import examples.dates as example
except Exception as err:
self.fail(err)
output = example.main(self.config)
exp = [' 1 | 1977-06-14 | 1977-06-14 21:10:00 | 21:10:00 |',
' 2 | None | None | 0:00:00 |',
' 3 | None | None | 0:00:00 |']
self.assertEqual(output, exp)
example.DATA.append(('0000-00-00', None, '00:00:00'),)
self.assertRaises(mysql.connector.errors.IntegrityError,
example.main, self.config)
sys.modules.pop('examples.dates', None)
def test_engines(self):
"""examples/engines.py"""
try:
import examples.engines as example
except:
self.fail()
output = self._exec_main(example)
# Can't check output as it might be different per MySQL instance
# We check only if MyISAM is present
found = False
for line in output:
if line.find('MyISAM') > -1:
found = True
break
self.assertTrue(found, 'MyISAM engine not found in output')
sys.modules.pop('examples.engine', None)
def test_inserts(self):
"""examples/inserts.py"""
try:
import examples.inserts as example
except Exception as err:
self.fail(err)
exp = '077dcd0139015c0aa6fb82ed932f053e'
self._exec_main(example, exp)
sys.modules.pop('examples.inserts', None)
def test_transactions(self):
"""examples/transactions.py"""
db = mysql.connector.connect(**self.config)
r = tests.have_engine(db, 'InnoDB')
db.close()
if not r:
return
try:
import examples.transaction as example
except Exception as e:
self.fail(e)
exp = '3bd75261ffeb5624cdd754a43e2fd938'
self._exec_main(example, exp)
sys.modules.pop('examples.transaction', None)
def test_unicode(self):
"""examples/unicode.py"""
try:
import examples.unicode as example
except Exception as e:
self.fail(e)
output = self._exec_main(example)
if PY2:
exp = [u'Unicode string: ¿Habla español?',
u'Unicode string coming from db: ¿Habla español?']
else:
exp = ['Unicode string: ¿Habla español?',
'Unicode string coming from db: ¿Habla español?']
self.assertEqual(output, exp)
sys.modules.pop('examples.unicode', None)
def test_warnings(self):
"""examples/warnings.py"""
try:
import examples.warnings as example
except Exception as e:
self.fail(e)
output = self._exec_main(example)
exp = ["Executing 'SELECT 'abc'+1'",
"1292: Truncated incorrect DOUBLE value: 'abc'"]
self.assertEqual(output, exp, 'Output was not correct')
example.STMT = "SELECT 'abc'"
self.assertRaises(Exception, example.main, self.config)
sys.modules.pop('examples.warnings', None)
def test_multi_resultsets(self):
"""examples/multi_resultsets.py"""
try:
import examples.multi_resultsets as example
except Exception as e:
self.fail(e)
output = self._exec_main(example)
exp = ['Inserted 1 row', 'Number of rows: 1', 'Inserted 2 rows',
'Names in table: Geert Jan Michel']
self.assertEqual(output, exp, 'Output was not correct')
sys.modules.pop('examples.resultsets', None)
def test_microseconds(self):
"""examples/microseconds.py"""
try:
import examples.microseconds as example
except Exception as e:
self.fail(e)
output = self._exec_main(example)
if self.cnx.get_server_version() < (5, 6, 4):
exp = "does not support fractional precision for timestamps."
self.assertTrue(output[0].endswith(exp))
else:
exp = [
' 1 | 1 | 0:00:47.510000 | 2009-06-07 09:15:02.000234',
' 1 | 2 | 0:00:47.020000 | 2009-06-07 09:30:05.102345',
' 1 | 3 | 0:00:47.650000 | 2009-06-07 09:50:23.002300',
' 1 | 4 | 0:00:46.060000 | 2009-06-07 10:30:56.000001',
]
self.assertEqual(output, exp)
sys.modules.pop('examples.microseconds', None)
def test_prepared_statements(self):
"""examples/prepared_statements.py"""
try:
import examples.prepared_statements as example
except Exception as e:
self.fail(e)
output = self._exec_main(example)
exp = [
'Inserted data',
'1 | Geert',
'2 | Jan',
'3 | Michel',
]
self.assertEqual(output, exp, 'Output was not correct')
sys.modules.pop('examples.prepared_statements', None)
class TestExamplesCExt(TestExamples):
def setUp(self):
self.config = tests.get_mysql_config()
self.config['use_pure'] = False
self.cnx = mysql.connector.connect(**self.config)
def tearDown(self):
self.cnx.close()
def test_prepared_statements(self):
pass

View File

@@ -1,658 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2013, 2015, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.fabric
"""
import datetime
from decimal import Decimal
import time
import unittest
import uuid
try:
from xmlrpclib import Fault, ServerProxy
except ImportError:
# Python v3
from xmlrpc.client import Fault, ServerProxy # pylint: disable=F0401
import tests
import mysql.connector
from mysql.connector import fabric, errorcode
from mysql.connector.fabric import connection, balancing
from mysql.connector.catch23 import UNICODE_TYPES, PY2
from mysql.connector.pooling import PooledMySQLConnection
ERR_NO_FABRIC_CONFIG = "Fabric configuration not available"
def wait_for_gtid(cur, gtid):
cur.execute("SELECT WAIT_UNTIL_SQL_THREAD_AFTER_GTIDS(%s, 2)", (gtid,))
cur.fetchall()
class FabricModuleTests(tests.MySQLConnectorTests):
"""Testing mysql.connector.fabric module"""
def test___all___(self):
attrs = [
'MODE_READWRITE',
'MODE_READONLY',
'STATUS_PRIMARY',
'STATUS_SECONDARY',
'SCOPE_GLOBAL',
'SCOPE_LOCAL',
'FabricMySQLServer',
'FabricShard',
'connect',
'Fabric',
'FabricConnection',
'MySQLFabricConnection',
]
for attr in attrs:
try:
getattr(fabric, attr)
except AttributeError:
self.fail("Attribute '{0}' not in fabric.__all__".format(attr))
def test_fabricmyqlserver(self):
attrs = ['uuid', 'group', 'host', 'port', 'mode', 'status', 'weight']
try:
nmdtpl = fabric.FabricMySQLServer(*([''] * len(attrs)))
except TypeError:
self.fail("Fail creating namedtuple FabricMySQLServer")
self.check_namedtuple(nmdtpl, attrs)
def test_fabricshard(self):
attrs = [
'database', 'table', 'column', 'key', 'shard', 'shard_type',
'group', 'global_group'
]
try:
nmdtpl = fabric.FabricShard(*([''] * len(attrs)))
except TypeError:
self.fail("Fail creating namedtuple FabricShard")
self.check_namedtuple(nmdtpl, attrs)
def test_connect(self):
class FakeConnection(object):
def __init__(self, *args, **kwargs):
pass
orig = fabric.MySQLFabricConnection
fabric.MySQLFabricConnection = FakeConnection
self.assertTrue(isinstance(fabric.connect(), FakeConnection))
fabric.MySQLFabricConnection = orig
class ConnectionModuleTests(tests.MySQLConnectorTests):
"""Testing mysql.connector.fabric.connection module"""
def test_module_variables(self):
error_codes = (
errorcode.CR_SERVER_LOST,
errorcode.ER_OPTION_PREVENTS_STATEMENT,
)
self.assertEqual(error_codes, connection.RESET_CACHE_ON_ERROR)
modvars = {
'MYSQL_FABRIC_PORT': {
'xmlrpc': 32274, 'mysql': 32275
},
'DEFAULT_FABRIC_PROTOCOL': 'xmlrpc',
'FABRICS': {},
'_CNX_ATTEMPT_DELAY': 1,
'_CNX_ATTEMPT_MAX': 3,
'_GETCNX_ATTEMPT_DELAY': 1,
'_GETCNX_ATTEMPT_MAX': 3,
'MODE_READONLY': 1,
'MODE_WRITEONLY': 2,
'MODE_READWRITE': 3,
'STATUS_FAULTY': 0,
'STATUS_SPARE': 1,
'STATUS_SECONDARY': 2,
'STATUS_PRIMARY': 3,
'SCOPE_GLOBAL': 'GLOBAL',
'SCOPE_LOCAL': 'LOCAL',
'_SERVER_STATUS_FAULTY': 'FAULTY',
}
for modvar, value in modvars.items():
try:
self.assertEqual(value, getattr(connection, modvar))
except AttributeError:
self.fail("Module variable connection.{0} not found".format(
modvar))
def test_cnx_properties(self):
cnxprops = {
# name: (valid_types, description, default)
'group': ((str,), "Name of group of servers", None),
'key': (tuple([int, str, datetime.datetime,
datetime.date] + list(UNICODE_TYPES)),
"Sharding key", None),
'tables': ((tuple, list), "List of tables in query", None),
'mode': ((int,), "Read-Only, Write-Only or Read-Write",
connection.MODE_READWRITE),
'shard': ((str,), "Identity of the shard for direct connection",
None),
'mapping': ((str,), "", None),
'scope': ((str,), "GLOBAL for accessing Global Group, or LOCAL",
connection.SCOPE_LOCAL),
'attempts': ((int,), "Attempts for getting connection",
connection._CNX_ATTEMPT_MAX),
'attempt_delay': ((int,), "Seconds to wait between each attempt",
connection._CNX_ATTEMPT_DELAY),
}
for prop, desc in cnxprops.items():
try:
self.assertEqual(desc, connection._CNX_PROPERTIES[prop])
except KeyError:
self.fail("Connection property '{0}'' not available".format(
prop))
self.assertEqual(len(cnxprops), len(connection._CNX_PROPERTIES))
def test__fabric_xmlrpc_uri(self):
data = ('example.com', 32274)
exp = 'http://{host}:{port}'.format(host=data[0], port=data[1])
self.assertEqual(exp, connection._fabric_xmlrpc_uri(*data))
def test__fabric_server_uuid(self):
data = ('example.com', 32274)
url = 'http://{host}:{port}'.format(host=data[0], port=data[1])
exp = uuid.uuid3(uuid.NAMESPACE_URL, url)
self.assertEqual(exp, connection._fabric_server_uuid(*data))
def test__validate_ssl_args(self):
func = connection._validate_ssl_args
kwargs = dict(ssl_ca=None, ssl_key=None, ssl_cert=None)
self.assertEqual(None, func(**kwargs))
kwargs = dict(ssl_ca=None, ssl_key='/path/to/key',
ssl_cert=None)
self.assertRaises(AttributeError, func, **kwargs)
kwargs = dict(ssl_ca='/path/to/ca', ssl_key='/path/to/key',
ssl_cert=None)
self.assertRaises(AttributeError, func, **kwargs)
exp = {
'ca': '/path/to/ca',
'key': None,
'cert': None,
}
kwargs = dict(ssl_ca='/path/to/ca', ssl_key=None, ssl_cert=None)
self.assertEqual(exp, func(**kwargs))
exp = {
'ca': '/path/to/ca',
'key': '/path/to/key',
'cert': '/path/to/cert',
}
res = func(ssl_ca=exp['ca'], ssl_cert=exp['cert'], ssl_key=exp['key'])
self.assertEqual(exp, res)
def test_extra_failure_report(self):
func = connection.extra_failure_report
func([])
self.assertEqual([], connection.REPORT_ERRORS_EXTRA)
self.assertRaises(AttributeError, func, 1)
self.assertRaises(AttributeError, func, [1])
exp = [2222]
func(exp)
self.assertEqual(exp, connection.REPORT_ERRORS_EXTRA)
class FabricBalancingBaseScheduling(tests.MySQLConnectorTests):
"""Test fabric.balancing.BaseScheduling"""
def setUp(self):
self.obj = balancing.BaseScheduling()
def test___init__(self):
self.assertEqual([], self.obj._members)
self.assertEqual([], self.obj._ratios)
def test_set_members(self):
self.assertRaises(NotImplementedError, self.obj.set_members, 'spam')
def test_get_next(self):
self.assertRaises(NotImplementedError, self.obj.get_next)
class FabricBalancingWeightedRoundRobin(tests.MySQLConnectorTests):
"""Test fabric.balancing.WeightedRoundRobin"""
def test___init__(self):
balancer = balancing.WeightedRoundRobin()
self.assertEqual([], balancer._members)
self.assertEqual([], balancer._ratios)
self.assertEqual([], balancer._load)
# init with args
class FakeWRR(balancing.WeightedRoundRobin):
def set_members(self, *args):
self.set_members_called = True
balancer = FakeWRR('ham', 'spam')
self.assertTrue(balancer.set_members_called)
def test_members(self):
balancer = balancing.WeightedRoundRobin()
self.assertEqual([], balancer.members)
balancer._members = ['ham']
self.assertEqual(['ham'], balancer.members)
def test_ratios(self):
balancer = balancing.WeightedRoundRobin()
self.assertEqual([], balancer.ratios)
balancer._ratios = ['ham']
self.assertEqual(['ham'], balancer.ratios)
def test_load(self):
balancer = balancing.WeightedRoundRobin()
self.assertEqual([], balancer.load)
balancer._load = ['ham']
self.assertEqual(['ham'], balancer.load)
def test_set_members(self):
balancer = balancing.WeightedRoundRobin()
balancer._members = ['ham']
balancer.set_members()
self.assertEqual([], balancer.members)
servers = [('ham1', 0.2), ('ham2', 0.8)]
balancer.set_members(*servers)
exp = [('ham2', Decimal('0.8')), ('ham1', Decimal('0.2'))]
self.assertEqual(exp, balancer.members)
self.assertEqual([400, 100], balancer.ratios)
self.assertEqual([0, 0], balancer.load)
def test_reset_load(self):
balancer = balancing.WeightedRoundRobin(*[('ham1', 0.2), ('ham2', 0.8)])
balancer._load = [5, 6]
balancer.reset()
self.assertEqual([0, 0], balancer.load)
def test_get_next(self):
servers = [('ham1', 0.2), ('ham2', 0.8)]
balancer = balancing.WeightedRoundRobin(*servers)
self.assertEqual(('ham2', Decimal('0.8')), balancer.get_next())
self.assertEqual([1, 0], balancer.load)
balancer._load = [80, 0]
self.assertEqual(('ham1', Decimal('0.2')), balancer.get_next())
self.assertEqual([80, 1], balancer.load)
balancer._load = [80, 20]
self.assertEqual(('ham2', Decimal('0.8')), balancer.get_next())
self.assertEqual([81, 20], balancer.load)
servers = [('ham1', 0.1), ('ham2', 0.2), ('ham3', 0.7)]
balancer = balancing.WeightedRoundRobin(*servers)
exp_sum = count = 101
while count > 0:
count -= 1
_ = balancer.get_next()
self.assertEqual(exp_sum, sum(balancer.load))
self.assertEqual([34, 34, 33], balancer.load)
servers = [('ham1', 0.2), ('ham2', 0.2), ('ham3', 0.7)]
balancer = balancing.WeightedRoundRobin(*servers)
exp_sum = count = 101
while count > 0:
count -= 1
_ = balancer.get_next()
self.assertEqual(exp_sum, sum(balancer.load))
self.assertEqual([34, 34, 33], balancer.load)
servers = [('ham1', 0.25), ('ham2', 0.25),
('ham3', 0.25), ('ham4', 0.25)]
balancer = balancing.WeightedRoundRobin(*servers)
exp_sum = count = 101
while count > 0:
count -= 1
_ = balancer.get_next()
self.assertEqual(exp_sum, sum(balancer.load))
self.assertEqual([26, 25, 25, 25], balancer.load)
servers = [('ham1', 0.5), ('ham2', 0.5)]
balancer = balancing.WeightedRoundRobin(*servers)
count = 201
while count > 0:
count -= 1
_ = balancer.get_next()
self.assertEqual(1, sum(balancer.load))
self.assertEqual([1, 0], balancer.load)
def test___repr__(self):
balancer = balancing.WeightedRoundRobin(*[('ham1', 0.2), ('ham2', 0.8)])
exp = ("<class 'mysql.connector.fabric.balancing.WeightedRoundRobin'>"
"(load=[0, 0], ratios=[400, 100])")
self.assertEqual(exp, repr(balancer))
def test___eq__(self):
servers = [('ham1', 0.2), ('ham2', 0.8)]
balancer1 = balancing.WeightedRoundRobin(*servers)
balancer2 = balancing.WeightedRoundRobin(*servers)
self.assertTrue(balancer1 == balancer2)
servers = [('ham1', 0.2), ('ham2', 0.3), ('ham3', 0.5)]
balancer3 = balancing.WeightedRoundRobin(*servers)
self.assertFalse(balancer1 == balancer3)
@unittest.skipIf(not tests.FABRIC_CONFIG, ERR_NO_FABRIC_CONFIG)
class FabricShardingTests(tests.MySQLConnectorTests):
"""Test Fabric's sharding"""
emp_data = {
1985: [
(10001, datetime.date(1953, 9, 2), u'Georgi', u'Facello', u'M',
datetime.date(1986, 6, 26)),
(10002, datetime.date(1964, 6, 2), u'Bezalel', u'Simmel', u'F',
datetime.date(1985, 11, 21)),
],
2000: [
(47291, datetime.date(1960, 9, 9), u'Ulf', u'Flexer', u'M',
datetime.date(2000, 1, 12)),
(60134, datetime.date(1964, 4, 21), u'Seshu', u'Rathonyi', u'F',
datetime.date(2000, 1, 2)),
]
}
def setUp(self):
self.cnx = mysql.connector.connect(
fabric=tests.FABRIC_CONFIG, user='root', database='employees'
)
def _check_table(self, table, shard_type):
fabric = self.cnx._fabric
fab_set = fabric.execute("sharding", "lookup_table", table)
found = False
if fab_set.rowcount:
for row in fab_set.rows():
if (row.table_name == table and row.type_name == shard_type):
found = True
break
if found == False:
raise ValueError(
"Table {table} not found or wrong sharding type".format(
table=table))
return True
def _populate(self, cnx, wait_gtid, table, insert, data, shard_key_index):
for employee in data:
cnx.set_property(tables=["employees." + table],
key=employee[shard_key_index],
scope=fabric.SCOPE_LOCAL,
mode=fabric.MODE_READWRITE)
cur = cnx.cursor()
wait_for_gtid(cur, wait_gtid)
cur.execute(insert, employee)
cnx.commit()
def _truncate(self, cur, table):
cur.execute("TRUNCATE TABLE {0}".format(table))
cur.execute("SELECT @@global.gtid_executed")
return cur.fetchone()[0]
def test_range(self):
self.assertTrue(self._check_table("employees.employees_range", 'RANGE'))
tbl_name = "employees_range"
tables = ["employees.{0}".format(tbl_name)]
self.cnx.set_property(tables=tables,
scope=fabric.SCOPE_GLOBAL,
mode=fabric.MODE_READWRITE)
cur = self.cnx.cursor()
gtid_executed = self._truncate(cur, tbl_name)
self.cnx.commit()
insert = ("INSERT INTO {0} "
"VALUES (%s, %s, %s, %s, %s, %s)").format(tbl_name)
self._populate(self.cnx, gtid_executed, tbl_name, insert,
self.emp_data[1985] + self.emp_data[2000], 0)
time.sleep(2)
# Year is key of self.emp_data, second value is emp_no for RANGE key
exp_keys = [(1985, 10002), (2000, 47291)]
for year, emp_no in exp_keys:
self.cnx.set_property(tables=tables,
scope=fabric.SCOPE_LOCAL,
key=emp_no, mode=fabric.MODE_READONLY)
cur = self.cnx.cursor()
cur.execute("SELECT * FROM {0}".format(tbl_name))
rows = cur.fetchall()
self.assertEqual(rows, self.emp_data[year])
self.cnx.set_property(tables=tables,
key='spam', mode=fabric.MODE_READONLY)
self.assertRaises(ValueError, self.cnx.cursor)
def test_range_datetime(self):
self.assertTrue(self._check_table(
"employees.employees_range_datetime", 'RANGE_DATETIME'))
tbl_name = "employees_range_datetime"
tables = ["employees.{0}".format(tbl_name)]
self.cnx.set_property(tables=tables,
scope=fabric.SCOPE_GLOBAL,
mode=fabric.MODE_READWRITE)
cur = self.cnx.cursor()
gtid_executed = self._truncate(cur, tbl_name)
self.cnx.commit()
insert = ("INSERT INTO {0} "
"VALUES (%s, %s, %s, %s, %s, %s)").format(tbl_name)
self._populate(self.cnx, gtid_executed, tbl_name, insert,
self.emp_data[1985] + self.emp_data[2000], 5)
time.sleep(2)
hire_dates = [datetime.date(1985, 1, 1), datetime.date(2000, 1, 1)]
for hire_date in hire_dates:
self.cnx.set_property(tables=tables,
key=hire_date, mode=fabric.MODE_READONLY)
cur = self.cnx.cursor()
cur.execute("SELECT * FROM {0}".format(tbl_name))
rows = cur.fetchall()
self.assertEqual(rows, self.emp_data[hire_date.year])
self.cnx.set_property(tables=tables,
key='2014-01-02', mode=fabric.MODE_READONLY)
self.assertRaises(ValueError, self.cnx.cursor)
def test_range_string(self):
self.assertTrue(self._check_table(
"employees.employees_range_string", 'RANGE_STRING'))
tbl_name = "employees_range_string"
tables = ["employees.{0}".format(tbl_name)]
self.cnx.set_property(tables=tables,
scope=fabric.SCOPE_GLOBAL,
mode=fabric.MODE_READWRITE)
cur = self.cnx.cursor()
gtid_executed = self._truncate(cur, tbl_name)
self.cnx.commit()
insert = ("INSERT INTO {0} "
"VALUES (%s, %s, %s, %s, %s, %s)").format(tbl_name)
self._populate(self.cnx, gtid_executed, tbl_name, insert,
self.emp_data[1985] + self.emp_data[2000], 3)
time.sleep(2)
emp_exp_range_string = {
'A': [self.emp_data[1985][0],
self.emp_data[2000][0]],
'M': [self.emp_data[1985][1],
self.emp_data[2000][1]],
}
str_keys = [u'A', u'M']
for str_key in str_keys:
self.cnx.set_property(tables=tables,
key=str_key, mode=fabric.MODE_READONLY)
cur = self.cnx.cursor()
cur.execute("SELECT * FROM {0}".format(tbl_name))
rows = cur.fetchall()
self.assertEqual(rows, emp_exp_range_string[str_key])
self.cnx.set_property(tables=tables,
key=b'not unicode str', mode=fabric.MODE_READONLY)
self.assertRaises(ValueError, self.cnx.cursor)
self.cnx.set_property(tables=tables,
key=12345, mode=fabric.MODE_READONLY)
self.assertRaises(ValueError, self.cnx.cursor)
if PY2:
self.cnx.set_property(tables=tables,
key='not unicode str',
mode=fabric.MODE_READONLY)
self.assertRaises(ValueError, self.cnx.cursor)
def test_bug19642249(self):
self.assertTrue(self._check_table(
"employees.employees_range_string", 'RANGE_STRING'))
# Invalid key for RANGE_STRING
tbl_name = "employees_range_string"
tables = ["employees.{0}".format(tbl_name)]
self.cnx.set_property(tables=tables,
key=u'1', mode=fabric.MODE_READONLY)
try:
cur = self.cnx.cursor()
except ValueError as exc:
self.assertEqual("Key invalid; was '1'", str(exc))
else:
self.fail("ValueError not raised")
# Invalid key for RANGE_DATETIME
tbl_name = "employees_range_datetime"
tables = ["employees.{0}".format(tbl_name)]
self.cnx.set_property(tables=tables,
key=datetime.date(1977, 1, 1),
mode=fabric.MODE_READONLY)
try:
cur = self.cnx.cursor()
except ValueError as exc:
self.assertEqual("Key invalid; was '1977-01-01'", str(exc))
else:
self.fail("ValueError not raised")
def test_bug19331658(self):
"""Pooling not working with fabric
"""
self.assertRaises(
AttributeError, mysql.connector.connect,
fabric=tests.FABRIC_CONFIG, user='root', database='employees',
pool_name='mypool')
pool_size = 2
cnx = mysql.connector.connect(
fabric=tests.FABRIC_CONFIG, user='root', database='employees',
pool_size=pool_size, pool_reset_session=False
)
tbl_name = "employees_range"
tables = ["employees.{0}".format(tbl_name)]
cnx.set_property(tables=tables,
scope=fabric.SCOPE_GLOBAL,
mode=fabric.MODE_READWRITE)
cnx.cursor()
self.assertTrue(isinstance(cnx._mysql_cnx, PooledMySQLConnection))
data = self.emp_data[1985]
for emp in data:
cnx.set_property(tables=tables,
key=emp[0],
scope=fabric.SCOPE_LOCAL,
mode=fabric.MODE_READWRITE)
cnx.cursor()
mysqlserver = cnx._fabric_mysql_server
config = cnx._mysql_config
self.assertEqual(
cnx._mysql_cnx.pool_name, "{0}_{1}_{2}_{3}".format(
mysqlserver.host, mysqlserver.port, config['user'],
config['database'])
)
def test_range_hash(self):
self.assertTrue(self._check_table(
"employees.employees_hash", 'HASH'))
tbl_name = "employees_hash"
tables = ["employees.{0}".format(tbl_name)]
self.cnx.set_property(tables=tables,
scope=fabric.SCOPE_GLOBAL,
mode=fabric.MODE_READWRITE)
cur = self.cnx.cursor()
gtid_executed = self._truncate(cur, tbl_name)
self.cnx.commit()
insert = ("INSERT INTO {0} "
"VALUES (%s, %s, %s, %s, %s, %s)").format(tbl_name)
self._populate(self.cnx, gtid_executed, tbl_name, insert,
self.emp_data[1985] + self.emp_data[2000], 3)
time.sleep(2)
emp_exp_hash = self.emp_data[1985] + self.emp_data[2000]
rows = []
self.cnx.reset_properties()
str_keys = ['group1', 'group2']
for str_key in str_keys:
self.cnx.set_property(group=str_key, mode=fabric.MODE_READONLY)
cur = self.cnx.cursor()
cur.execute("SELECT * FROM {0}".format(tbl_name))
rows += cur.fetchall()
self.assertEqual(rows, emp_exp_hash)

View File

@@ -1,118 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2012, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.locales
"""
from datetime import datetime
import tests
from . import PY2
from mysql.connector import errorcode, locales
def _get_client_errors():
errors = {}
for name in dir(errorcode):
if name.startswith('CR_'):
errors[name] = getattr(errorcode, name)
return errors
class LocalesModulesTests(tests.MySQLConnectorTests):
def test_defaults(self):
# There should always be 'eng'
try:
from mysql.connector.locales import eng # pylint: disable=W0612
except ImportError:
self.fail("locales.eng could not be imported")
# There should always be 'eng.client_error'
some_error = None
try:
from mysql.connector.locales.eng import client_error
some_error = client_error.CR_UNKNOWN_ERROR
except ImportError:
self.fail("locales.eng.client_error could not be imported")
some_error = some_error + '' # fool pylint
def test_get_client_error(self):
try:
locales.get_client_error(2000, language='spam')
except ImportError as err:
self.assertEqual("No localization support for language 'spam'",
str(err))
else:
self.fail("ImportError not raised")
exp = "Unknown MySQL error"
self.assertEqual(exp, locales.get_client_error(2000))
self.assertEqual(exp, locales.get_client_error('CR_UNKNOWN_ERROR'))
try:
locales.get_client_error(tuple())
except ValueError as err:
self.assertEqual(
"error argument needs to be either an integer or string",
str(err))
else:
self.fail("ValueError not raised")
class LocalesEngClientErrorTests(tests.MySQLConnectorTests):
"""Testing locales.eng.client_error"""
def test__MYSQL_VERSION(self):
try:
from mysql.connector.locales.eng import client_error
except ImportError:
self.fail("locales.eng.client_error could not be imported")
minimum = (5, 6, 6)
self.assertTrue(isinstance(client_error._MYSQL_VERSION, tuple))
self.assertTrue(len(client_error._MYSQL_VERSION) == 3)
self.assertTrue(client_error._MYSQL_VERSION >= minimum)
def test_messages(self):
try:
from mysql.connector.locales.eng import client_error
except ImportError:
self.fail("locales.eng.client_error could not be imported")
errors = _get_client_errors()
count = 0
for name in dir(client_error):
if name.startswith('CR_'):
count += 1
self.assertEqual(len(errors), count)
if PY2:
strtype = unicode # pylint: disable=E0602
else:
strtype = str
for name in errors.keys():
self.assertTrue(isinstance(getattr(client_error, name), strtype))

View File

@@ -1,417 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for MySQL data types
"""
from decimal import Decimal
import time
import datetime
from mysql.connector import connection, errors
import tests
from tests import foreach_cnx, cnx_config
try:
from mysql.connector.connection_cext import CMySQLConnection
except ImportError:
# Test without C Extension
CMySQLConnection = None
def _get_insert_stmt(tbl, cols):
insert = "INSERT INTO {table} ({columns}) values ({values})".format(
table=tbl,
columns=','.join(cols),
values=','.join(['%s'] * len(cols))
)
return insert
def _get_select_stmt(tbl, cols):
select = "SELECT {columns} FROM {table} ORDER BY id".format(
columns=','.join(cols),
table=tbl
)
return select
class TestsDataTypes(tests.MySQLConnectorTests):
tables = {
'bit': 'myconnpy_mysql_bit',
'int': 'myconnpy_mysql_int',
'bool': 'myconnpy_mysql_bool',
'float': 'myconnpy_mysql_float',
'decimal': 'myconnpy_mysql_decimal',
'temporal': 'myconnpy_mysql_temporal',
'temporal_year': 'myconnpy_mysql_temporal_year',
'set': 'myconnpy_mysql_set',
}
def compare(self, name, val1, val2):
self.assertEqual(val1, val2, "%s %s != %s" % (name, val1, val2))
def drop_tables(self, cnx):
cur = cnx.cursor()
table_names = self.tables.values()
cur.execute("DROP TABLE IF EXISTS {tables}".format(
tables=','.join(table_names))
)
cur.close()
class TestsCursor(TestsDataTypes):
def setUp(self):
pass
#self.config = tests.get_mysql_config()
#cnx = connection.MySQLConnection(**self.config)
#self.drop_tables(cnx)
def tearDown(self):
pass
#cnx = connection.MySQLConnection(**self.config)
#self.drop_tables(cnx)
#cnx.close()
@foreach_cnx()
def test_numeric_int(self):
tbl = self.tables['int']
self.cnx.cmd_query("DROP TABLE IF EXISTS {0}".format(tbl))
cur = self.cnx.cursor()
columns = [
'tinyint_signed',
'tinyint_unsigned',
'bool_signed',
'smallint_signed',
'smallint_unsigned',
'mediumint_signed',
'mediumint_unsigned',
'int_signed',
'int_unsigned',
'bigint_signed',
'bigint_unsigned',
]
cur.execute((
"CREATE TABLE {table} ("
"`id` TINYINT UNSIGNED NOT NULL AUTO_INCREMENT,"
"`tinyint_signed` TINYINT SIGNED,"
"`tinyint_unsigned` TINYINT UNSIGNED,"
"`bool_signed` BOOL,"
"`smallint_signed` SMALLINT SIGNED,"
"`smallint_unsigned` SMALLINT UNSIGNED,"
"`mediumint_signed` MEDIUMINT SIGNED,"
"`mediumint_unsigned` MEDIUMINT UNSIGNED,"
"`int_signed` INT SIGNED,"
"`int_unsigned` INT UNSIGNED,"
"`bigint_signed` BIGINT SIGNED,"
"`bigint_unsigned` BIGINT UNSIGNED,"
"PRIMARY KEY (id))"
).format(table=tbl)
)
data = [
(
-128, # tinyint signed
0, # tinyint unsigned
0, # boolean
-32768, # smallint signed
0, # smallint unsigned
-8388608, # mediumint signed
0, # mediumint unsigned
-2147483648, # int signed
0, # int unsigned
-9223372036854775808, # big signed
0, # big unsigned
),
(
127, # tinyint signed
255, # tinyint unsigned
127, # boolean
32767, # smallint signed
65535, # smallint unsigned
8388607, # mediumint signed
16777215, # mediumint unsigned
2147483647, # int signed
4294967295, # int unsigned
9223372036854775807, # big signed
18446744073709551615, # big unsigned
)
]
insert = _get_insert_stmt(tbl, columns)
select = _get_select_stmt(tbl, columns)
cur.executemany(insert, data)
cur.execute(select)
rows = cur.fetchall()
for i, col in enumerate(columns):
self.compare(col, data[0][i], rows[0][i])
self.compare(col, data[1][i], rows[1][i])
cur.close()
@foreach_cnx()
def test_numeric_bit(self):
tbl = self.tables['bit']
self.cnx.cmd_query("DROP TABLE IF EXISTS {0}".format(tbl))
cur = self.cnx.cursor()
columns = [
'c8', 'c16', 'c24', 'c32',
'c40', 'c48', 'c56', 'c63',
'c64']
cur.execute((
"CREATE TABLE {table} ("
"`id` int NOT NULL AUTO_INCREMENT,"
"`c8` bit(8) DEFAULT NULL,"
"`c16` bit(16) DEFAULT NULL,"
"`c24` bit(24) DEFAULT NULL,"
"`c32` bit(32) DEFAULT NULL,"
"`c40` bit(40) DEFAULT NULL,"
"`c48` bit(48) DEFAULT NULL,"
"`c56` bit(56) DEFAULT NULL,"
"`c63` bit(63) DEFAULT NULL,"
"`c64` bit(64) DEFAULT NULL,"
"PRIMARY KEY (id))"
).format(table=tbl)
)
insert = _get_insert_stmt(tbl, columns)
select = _get_select_stmt(tbl, columns)
data = list()
data.append(tuple([0] * len(columns)))
values = list()
for col in columns:
values.append((1 << int(col.replace('c', ''))) - 1)
data.append(tuple(values))
values = list()
for col in columns:
bits = int(col.replace('c', ''))
values.append((1 << bits) - 1)
data.append(tuple(values))
cur.executemany(insert, data)
cur.execute(select)
rows = cur.fetchall()
self.assertEqual(rows, data)
cur.close()
@foreach_cnx()
def test_numeric_float(self):
tbl = self.tables['float']
self.cnx.cmd_query("DROP TABLE IF EXISTS {0}".format(tbl))
cur = self.cnx.cursor()
columns = [
'float_signed',
'float_unsigned',
'double_signed',
'double_unsigned',
]
cur.execute((
"CREATE TABLE {table} ("
"`id` int NOT NULL AUTO_INCREMENT,"
"`float_signed` FLOAT(6,5) SIGNED,"
"`float_unsigned` FLOAT(6,5) UNSIGNED,"
"`double_signed` DOUBLE(15,10) SIGNED,"
"`double_unsigned` DOUBLE(15,10) UNSIGNED,"
"PRIMARY KEY (id))"
).format(table=tbl)
)
insert = _get_insert_stmt(tbl, columns)
select = _get_select_stmt(tbl, columns)
data = [
(-3.402823466, 0, -1.7976931348623157, 0,),
(-1.175494351, 3.402823466,
1.7976931348623157, 2.2250738585072014),
(-1.23455678, 2.999999, -1.3999999999999999, 1.9999999999999999),
]
cur.executemany(insert, data)
cur.execute(select)
rows = cur.fetchall()
for j in range(0, len(data)):
for i, col in enumerate(columns[0:2]):
self.compare(col, round(data[j][i], 5), rows[j][i])
for i, col in enumerate(columns[2:2]):
self.compare(col, round(data[j][i], 10), rows[j][i])
cur.close()
@foreach_cnx()
def test_numeric_decimal(self):
tbl = self.tables['decimal']
self.cnx.cmd_query("DROP TABLE IF EXISTS {0}".format(tbl))
cur = self.cnx.cursor()
columns = [
'decimal_signed',
'decimal_unsigned',
]
cur.execute((
"CREATE TABLE {table} ("
"`id` int NOT NULL AUTO_INCREMENT,"
"`decimal_signed` DECIMAL(65,30) SIGNED,"
"`decimal_unsigned` DECIMAL(65,30) UNSIGNED,"
"PRIMARY KEY (id))"
).format(table=tbl)
)
insert = _get_insert_stmt(tbl, columns)
select = _get_select_stmt(tbl, columns)
data = [
(Decimal(
'-9999999999999999999999999.999999999999999999999999999999'),
Decimal(
'+9999999999999999999999999.999999999999999999999999999999')),
(Decimal('-1234567.1234'),
Decimal('+123456789012345.123456789012345678901')),
(Decimal(
'-1234567890123456789012345.123456789012345678901234567890'),
Decimal(
'+1234567890123456789012345.123456789012345678901234567890')),
]
cur.executemany(insert, data)
cur.execute(select)
rows = cur.fetchall()
self.assertEqual(data, rows)
cur.close()
@foreach_cnx()
def test_temporal_datetime(self):
tbl = self.tables['temporal']
self.cnx.cmd_query("DROP TABLE IF EXISTS {0}".format(tbl))
cur = self.cnx.cursor()
cur.execute("SET SESSION time_zone = '+00:00'")
columns = [
't_date',
't_datetime',
't_time',
't_timestamp',
't_year_4',
]
cur.execute((
"CREATE TABLE {table} ("
"`id` int NOT NULL AUTO_INCREMENT,"
"`t_date` DATE,"
"`t_datetime` DATETIME,"
"`t_time` TIME,"
"`t_timestamp` TIMESTAMP DEFAULT 0,"
"`t_year_4` YEAR(4),"
"PRIMARY KEY (id))"
).format(table=tbl)
)
insert = _get_insert_stmt(tbl, columns)
select = _get_select_stmt(tbl, columns)
data = [
(datetime.date(2010, 1, 17),
datetime.datetime(2010, 1, 17, 19, 31, 12),
datetime.timedelta(hours=43, minutes=32, seconds=21),
datetime.datetime(2010, 1, 17, 19, 31, 12),
0),
(datetime.date(1000, 1, 1),
datetime.datetime(1000, 1, 1, 0, 0, 0),
datetime.timedelta(hours=-838, minutes=59, seconds=59),
datetime.datetime(*time.gmtime(1)[:6]),
1901),
(datetime.date(9999, 12, 31),
datetime.datetime(9999, 12, 31, 23, 59, 59),
datetime.timedelta(hours=838, minutes=59, seconds=59),
datetime.datetime(2038, 1, 19, 3, 14, 7),
2155),
]
cur.executemany(insert, data)
cur.execute(select)
rows = cur.fetchall()
for j in (range(0, len(data))):
for i, col in enumerate(columns):
self.compare("{column} (data[{count}])".format(
column=col, count=j), data[j][i], rows[j][i])
# Testing YEAR(2), which is now obsolete since MySQL 5.6.6
tblname = self.tables['temporal_year']
cur.execute("DROP TABLE IF EXISTS {0}".format(tblname))
stmt = (
"CREATE TABLE {table} ("
"`id` int NOT NULL AUTO_INCREMENT KEY, "
"`t_year_2` YEAR(2))".format(table=tblname)
)
if tests.MYSQL_VERSION >= (5, 7, 5):
# Support for YEAR(2) removed in MySQL 5.7.5
self.assertRaises(errors.DatabaseError, cur.execute, stmt)
else:
cur.execute(stmt)
cur.execute(_get_insert_stmt(tblname, ['t_year_2']), (10,))
cur.execute(_get_select_stmt(tblname, ['t_year_2']))
row = cur.fetchone()
if tests.MYSQL_VERSION >= (5, 6, 6):
self.assertEqual(2010, row[0])
else:
self.assertEqual(10, row[0])
cur.close()
@cnx_config(consume_results=True)
@foreach_cnx()
def test_set(self):
tbl = self.tables['temporal']
self.cnx.cmd_query("DROP TABLE IF EXISTS {0}".format(tbl))
cur = self.cnx.cursor()
cur.execute((
"CREATE TABLE {table} ("
"`id` int NOT NULL AUTO_INCREMENT,"
"c1 SET ('a', 'b', 'c'),"
"c2 SET ('1', '2', '3'),"
"c3 SET ('ham', 'spam'),"
"PRIMARY KEY (id))"
).format(table=tbl)
)
insert = (
"INSERT INTO {table} (c1, c2, c3) VALUES "
"('a,c', '1,3', 'spam'), ('b', '3,2', 'spam,spam,ham')"
).format(table=tbl)
cur.execute(insert)
cur.execute("SELECT * FROM {table}".format(table=tbl))
exp = [
(1, set([u'a', u'c']), set([u'1', u'3']), set([u'spam'])),
(2, set([u'b']), set([u'3', u'2']), set([u'ham', u'spam']))
]
self.assertEqual(exp, cur.fetchall())
cur.close()

View File

@@ -1,427 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2012, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.network
"""
import os
import socket
import logging
from collections import deque
import unittest
import tests
from mysql.connector import (network, errors, constants)
LOGGER = logging.getLogger(tests.LOGGER_NAME)
class NetworkTests(tests.MySQLConnectorTests):
"""Testing mysql.connector.network functions"""
def test__prepare_packets(self):
"""Prepare packets for sending"""
data = (b'abcdefghijklmn', 1)
exp = [b'\x0e\x00\x00\x01abcdefghijklmn']
self.assertEqual(exp, network._prepare_packets(*(data)))
data = (b'a' * (constants.MAX_PACKET_LENGTH + 1000), 2)
exp = [
b'\xff\xff\xff\x02' + (b'a' * constants.MAX_PACKET_LENGTH),
b'\xe8\x03\x00\x03' + (b'a' * 1000)
]
self.assertEqual(exp, network._prepare_packets(*(data)))
class BaseMySQLSocketTests(tests.MySQLConnectorTests):
"""Testing mysql.connector.network.BaseMySQLSocket"""
def setUp(self):
config = tests.get_mysql_config()
self._host = config['host']
self._port = config['port']
self.cnx = network.BaseMySQLSocket()
def tearDown(self):
try:
self.cnx.close_connection()
except:
pass
def _get_socket(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
LOGGER.debug("Get socket for {host}:{port}".format(
host=self._host, port=self._port))
sock.connect((self._host, self._port))
return sock
def test_init(self):
"""MySQLSocket initialization"""
exp = {
'sock': None,
'_connection_timeout': None,
'_packet_queue': deque(),
'recvsize': 1024 * 8,
}
for key, value in exp.items():
self.assertEqual(value, self.cnx.__dict__[key])
def test_next_packet_number(self):
"""Test packet number property"""
self.assertEqual(0, self.cnx.next_packet_number)
self.assertEqual(0, self.cnx._packet_number)
self.assertEqual(1, self.cnx.next_packet_number)
self.assertEqual(1, self.cnx._packet_number)
self.cnx._packet_number = 255
self.assertEqual(0, self.cnx.next_packet_number)
def test_open_connection(self):
"""Opening a connection"""
self.assertRaises(NotImplementedError, self.cnx.open_connection)
def test_get_address(self):
"""Get the address of a connection"""
self.assertRaises(NotImplementedError, self.cnx.get_address)
def test_shutdown(self):
"""Shutting down a connection"""
self.cnx.shutdown()
self.assertEqual(None, self.cnx.sock)
def test_close_connection(self):
"""Closing a connection"""
self.cnx.close_connection()
self.assertEqual(None, self.cnx.sock)
def test_send_plain(self):
"""Send plain data through the socket"""
data = b'asddfasdfasdf'
self.assertRaises(errors.OperationalError, self.cnx.send_plain,
data, 0)
self.cnx.sock = tests.DummySocket()
data = [
(b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22\x61\x62\x63\x22', 1),
(b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22'
+ (b'\x61' * (constants.MAX_PACKET_LENGTH + 1000)) + b'\x22', 2)]
self.assertRaises(Exception, self.cnx.send_plain, None, None)
for value in data:
exp = network._prepare_packets(*value)
try:
self.cnx.send_plain(*value)
except errors.Error as err:
self.fail("Failed sending pktnr {}: {}".format(value[1],
str(err)))
self.assertEqual(exp, self.cnx.sock._client_sends)
self.cnx.sock.reset()
def test_send_compressed(self):
"""Send compressed data through the socket"""
data = b'asddfasdfasdf'
self.assertRaises(errors.OperationalError, self.cnx.send_compressed,
data, 0)
self.cnx.sock = tests.DummySocket()
self.assertRaises(Exception, self.cnx.send_compressed, None, None)
# Small packet
data = (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22\x61\x62\x63\x22', 1)
exp = [b'\x11\x00\x00\x02\x00\x00\x00\r\x00\x00\x01\x03SELECT "abc"']
try:
self.cnx.send_compressed(*data)
except errors.Error as err:
self.fail("Failed sending pktnr {}: {}".format(data[1], err))
self.assertEqual(exp, self.cnx.sock._client_sends)
self.cnx.sock.reset()
# Slightly bigger packet (not getting compressed)
data = (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22\x61\x62\x63\x22', 1)
exp = (24, b'\x11\x00\x00\x03\x00\x00\x00\x0d\x00\x00\x01\x03'
b'\x53\x45\x4c\x45\x43\x54\x20\x22')
try:
self.cnx.send_compressed(*data)
except errors.Error as err:
self.fail("Failed sending pktnr {}: {}".format(data[1], str(err)))
received = self.cnx.sock._client_sends[0]
self.assertEqual(exp, (len(received), received[:20]))
self.cnx.sock.reset()
# Big packet
data = (b'\x03\x53\x45\x4c\x45\x43\x54\x20\x22'
+ b'\x61' * (constants.MAX_PACKET_LENGTH + 1000) + b'\x22', 2)
exp = [
(63, b'\x38\x00\x00\x04\x00\x40\x00\x78\x9c\xed\xc1\x31'
b'\x0d\x00\x20\x0c\x00\xb0\x04\x8c'),
(16322, b'\xbb\x3f\x00\x05\xf9\xc3\xff\x78\x9c\xec\xc1\x81'
b'\x00\x00\x00\x00\x80\x20\xd6\xfd')]
try:
self.cnx.send_compressed(*data)
except errors.Error as err:
self.fail("Failed sending pktnr {}: {}".format(data[1], str(err)))
received = [(len(r), r[:20]) for r in self.cnx.sock._client_sends]
self.assertEqual(exp, received)
self.cnx.sock.reset()
def test_recv_plain(self):
"""Receive data from the socket"""
self.cnx.sock = tests.DummySocket()
def get_address():
return 'dummy'
self.cnx.get_address = get_address
# Receive a packet which is not 4 bytes long
self.cnx.sock.add_packet(b'\01\01\01')
self.assertRaises(errors.InterfaceError, self.cnx.recv_plain)
# Socket fails to receive and produces an error
self.cnx.sock.raise_socket_error()
self.assertRaises(errors.OperationalError, self.cnx.recv_plain)
# Receive packets after a query, SELECT "Ham"
exp = [
b'\x01\x00\x00\x01\x01',
b'\x19\x00\x00\x02\x03\x64\x65\x66\x00\x00\x00\x03\x48\x61\x6d\x00'
b'\x0c\x21\x00\x09\x00\x00\x00\xfd\x01\x00\x1f\x00\x00',
b'\x05\x00\x00\x03\xfe\x00\x00\x02\x00',
b'\x04\x00\x00\x04\x03\x48\x61\x6d',
b'\x05\x00\x00\x05\xfe\x00\x00\x02\x00',
]
self.cnx.sock.reset()
self.cnx.sock.add_packets(exp)
length_exp = len(exp)
result = []
packet = self.cnx.recv_plain()
while packet:
result.append(packet)
if length_exp == len(result):
break
packet = self.cnx.recv_plain()
self.assertEqual(exp, result)
def test_recv_compressed(self):
"""Receive compressed data from the socket"""
self.cnx.sock = tests.DummySocket()
def get_address():
return 'dummy'
self.cnx.get_address = get_address
# Receive a packet which is not 7 bytes long
self.cnx.sock.add_packet(b'\01\01\01\01\01\01')
self.assertRaises(errors.InterfaceError, self.cnx.recv_compressed)
# Receive the header of a packet, but nothing more
self.cnx.sock.add_packet(b'\01\00\00\00\00\00\00')
self.assertRaises(errors.InterfaceError, self.cnx.recv_compressed)
# Socket fails to receive and produces an error
self.cnx.sock.raise_socket_error()
self.assertRaises(errors.OperationalError, self.cnx.recv_compressed)
def test_set_connection_timeout(self):
"""Set the connection timeout"""
exp = 5
self.cnx.set_connection_timeout(exp)
self.assertEqual(exp, self.cnx._connection_timeout)
@unittest.skipIf(os.name == 'nt', "Skip UNIX Socket tests on Windows")
class MySQLUnixSocketTests(tests.MySQLConnectorTests):
"""Testing mysql.connector.network.MySQLUnixSocket"""
def setUp(self):
config = tests.get_mysql_config()
self._unix_socket = config['unix_socket']
self.cnx = network.MySQLUnixSocket(unix_socket=config['unix_socket'])
def tearDown(self):
try:
self.cnx.close_connection()
except:
pass
def test_init(self):
"""MySQLUnixSocket initialization"""
exp = {
'unix_socket': self._unix_socket,
}
for key, value in exp.items():
self.assertEqual(value, self.cnx.__dict__[key])
def test_get_address(self):
"""Get path to the Unix socket"""
exp = self._unix_socket
self.assertEqual(exp, self.cnx.get_address())
def test_open_connection(self):
"""Open a connection using a Unix socket"""
if os.name == 'nt':
self.assertRaises(errors.InterfaceError, self.cnx.open_connection)
else:
try:
self.cnx.open_connection()
except errors.Error as err:
self.fail(str(err))
@unittest.skipIf(not tests.SSL_AVAILABLE,
"Could not test switch to SSL. Make sure Python supports "
"SSL.")
def test_switch_to_ssl(self):
"""Switch the socket to use SSL"""
args = {
'ca': os.path.join(tests.SSL_DIR, 'tests_CA_cert.pem'),
'cert': os.path.join(tests.SSL_DIR, 'tests_client_cert.pem'),
'key': os.path.join(tests.SSL_DIR, 'tests_client_key.pem'),
'cipher': 'AES256-SHA'
}
self.assertRaises(errors.InterfaceError,
self.cnx.switch_to_ssl, **args)
# Handshake failure
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(4)
sock.connect(self._unix_socket)
self.cnx.sock = sock
self.assertRaises(errors.InterfaceError,
self.cnx.switch_to_ssl, **args)
class MySQLTCPSocketTests(tests.MySQLConnectorTests):
"""Testing mysql.connector.network..MySQLTCPSocket"""
def setUp(self):
config = tests.get_mysql_config()
self._host = config['host']
self._port = config['port']
self.cnx = network.MySQLTCPSocket(host=self._host, port=self._port)
def tearDown(self):
try:
self.cnx.close_connection()
except:
pass
def test_init(self):
"""MySQLTCPSocket initialization"""
exp = {
'server_host': self._host,
'server_port': self._port,
}
for key, value in exp.items():
self.assertEqual(value, self.cnx.__dict__[key])
def test_get_address(self):
"""Get TCP/IP address"""
exp = "%s:%s" % (self._host, self._port)
self.assertEqual(exp, self.cnx.get_address())
@unittest.skipIf(tests.IPV6_AVAILABLE, "Testing IPv6, not testing IPv4")
def test_open_connection__ipv4(self):
"""Open a connection using TCP"""
try:
self.cnx.open_connection()
except errors.Error as err:
self.fail(str(err))
config = tests.get_mysql_config()
self._host = config['host']
self._port = config['port']
cases = [
# Address, Expected Family, Should Raise, Force IPv6
(tests.get_mysql_config()['host'], socket.AF_INET, False, False),
]
for case in cases:
self._test_open_connection(*case)
@unittest.skipIf(not tests.IPV6_AVAILABLE, "IPv6 testing disabled")
def test_open_connection__ipv6(self):
"""Open a connection using TCP"""
config = tests.get_mysql_config()
self._host = config['host']
self._port = config['port']
cases = [
# Address, Expected Family, Should Raise, Force IPv6
('::1', socket.AF_INET6, False, False),
('2001::14:06:77', socket.AF_INET6, True, False),
('xx:00:xx', socket.AF_INET6, True, False),
]
for case in cases:
self._test_open_connection(*case)
def _test_open_connection(self, addr, family, should_raise, force):
try:
sock = network.MySQLTCPSocket(host=addr,
port=self._port,
force_ipv6=force)
sock.set_connection_timeout(1)
sock.open_connection()
except (errors.InterfaceError, socket.error):
if not should_raise:
self.fail('{0} incorrectly raised socket.error'.format(
addr))
else:
if should_raise:
self.fail('{0} should have raised socket.error'.format(
addr))
else:
self.assertEqual(family, sock._family,
"Family for {0} did not match".format(
addr, family, sock._family))
sock.close_connection()
@unittest.skipIf(not tests.SSL_AVAILABLE,
"Could not test switch to SSL. Make sure Python supports "
"SSL.")
def test_switch_to_ssl(self):
"""Switch the socket to use SSL"""
args = {
'ca': os.path.join(tests.SSL_DIR, 'tests_CA_cert.pem'),
'cert': os.path.join(tests.SSL_DIR, 'tests_client_cert.pem'),
'key': os.path.join(tests.SSL_DIR, 'tests_client_key.pem'),
}
self.assertRaises(errors.InterfaceError,
self.cnx.switch_to_ssl, **args)
# Handshake failure
(family, socktype, proto, _,
sockaddr) = socket.getaddrinfo(self._host, self._port)[0]
sock = socket.socket(family, socktype, proto)
sock.settimeout(4)
sock.connect(sockaddr)
self.cnx.sock = sock
self.assertRaises(errors.InterfaceError,
self.cnx.switch_to_ssl, **args)

View File

@@ -1,211 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
import logging
import os
import tests
from mysql.connector import connect
from mysql.connector.optionfiles import MySQLOptionsParser, read_option_files
LOGGER = logging.getLogger(tests.LOGGER_NAME)
class MySQLOptionsParserTests(tests.MySQLConnectorTests):
"""Class checking MySQLOptionsParser"""
def setUp(self):
self.option_file_dir = os.path.join('tests', 'data', 'option_files')
self.option_file_parser = MySQLOptionsParser(files=os.path.join(
self.option_file_dir, 'my.cnf'))
def test___init__(self):
self.assertRaises(ValueError, MySQLOptionsParser)
option_file_parser = MySQLOptionsParser(files=os.path.join(
self.option_file_dir, 'my.cnf'))
self.assertEqual(option_file_parser.files, [os.path.join(
self.option_file_dir, 'my.cnf')])
def test_optionxform(self):
"""Converts option strings
Converts option strings to lower case and replaces dashes(-) with
underscores(_) if keep_dashes variable is set.
"""
self.assertEqual('ham', self.option_file_parser.optionxform('HAM'))
self.assertEqual('ham-spam', self.option_file_parser.optionxform(
'HAM-SPAM'))
self.option_file_parser.keep_dashes = False
self.assertEqual('ham_spam', self.option_file_parser.optionxform(
'HAM-SPAM'))
def test__parse_options(self):
files = [
os.path.join(self.option_file_dir, 'include_files', '1.cnf'),
os.path.join(self.option_file_dir, 'include_files', '2.cnf'),
]
self.option_file_parser = MySQLOptionsParser(files)
self.assertRaises(ValueError, self.option_file_parser._parse_options,
'dummy_file.cnf')
self.option_file_parser._parse_options(files)
exp = {
'option1': '15',
'option2': '20'
}
self.assertEqual(exp, self.option_file_parser.get_groups('group2',
'group1'))
exp = {
'option3': '200'
}
self.assertEqual(exp, self.option_file_parser.get_groups('group3',
'group4'))
self.assertEqual(exp, self.option_file_parser.get_groups('group4',
'group3'))
def test_read(self,):
filename = os.path.join( self.option_file_dir, 'my.cnf')
self.assertEqual([filename], self.option_file_parser.read(filename))
filenames = [
os.path.join(self.option_file_dir, 'include_files', '1.cnf'),
os.path.join(self.option_file_dir, 'include_files', '2.cnf'),
]
self.assertEqual(filenames, self.option_file_parser.read(filenames))
self.assertEqual([], self.option_file_parser.read('dummy-file.cnf'))
def test_get_groups(self):
exp = {
'password': '12345',
'port': '1001',
'socket': '/var/run/mysqld/mysqld2.sock',
'ssl-ca': 'dummyCA',
'ssl-cert': 'dummyCert',
'ssl-key': 'dummyKey',
'ssl-cipher': 'AES256-SHA:CAMELLIA256-SHA',
'nice': '0',
'user': 'mysql',
'pid-file': '/var/run/mysqld/mysqld.pid',
'basedir': '/usr',
'datadir': '/var/lib/mysql',
'tmpdir': '/tmp',
'lc-messages-dir': '/usr/share/mysql',
'skip-external-locking': '',
'bind-address': '127.0.0.1',
'log_error': '/var/log/mysql/error.log',
}
self.assertEqual(exp, self.option_file_parser.get_groups('client',
'mysqld_safe',
'mysqld'))
def test_get_groups_as_dict(self):
exp = dict([
('client', {'port': '1000',
'password': '12345',
'socket': '/var/run/mysqld/mysqld.sock',
'ssl-ca': 'dummyCA',
'ssl-cert': 'dummyCert',
'ssl-key': 'dummyKey',
'ssl-cipher': 'AES256-SHA:CAMELLIA256-SHA'}),
('mysqld_safe', {'socket': '/var/run/mysqld/mysqld1.sock',
'nice': '0'}),
('mysqld', {'user': 'mysql',
'pid-file': '/var/run/mysqld/mysqld.pid',
'socket': '/var/run/mysqld/mysqld2.sock',
'port': '1001', 'basedir': '/usr',
'datadir': '/var/lib/mysql', 'tmpdir': '/tmp',
'lc-messages-dir': '/usr/share/mysql',
'skip-external-locking': '',
'bind-address': '127.0.0.1',
'log_error': '/var/log/mysql/error.log'}),
])
self.assertEqual(exp, self.option_file_parser.get_groups_as_dict())
def test_get_groups_as_dict_with_priority(self):
files = [
os.path.join(self.option_file_dir, 'include_files', '1.cnf'),
os.path.join(self.option_file_dir, 'include_files', '2.cnf'),
]
self.option_file_parser = MySQLOptionsParser(files)
exp = dict([
('group1', {'option1': ('15', 1),
'option2': ('20', 1)}),
('group2', {'option1': ('20', 1),
'option2': ('30', 1)}),
('group3', {'option3': ('100', 0)}),
('group4', {'option3': ('200', 1)}),
('mysql', {'user': ('ham', 0)}),
('client', {'user': ('spam', 1)})
])
self.assertEqual(
exp, self.option_file_parser.get_groups_as_dict_with_priority())
def test_read_option_files(self):
self.assertRaises(ValueError, read_option_files,
option_files='dummy_file.cnf')
option_file_dir = os.path.join('tests', 'data', 'option_files')
exp = {
'password': '12345',
'port': 1000,
'unix_socket': '/var/run/mysqld/mysqld.sock',
'ssl_ca': 'dummyCA',
'ssl_cert': 'dummyCert',
'ssl_key': 'dummyKey',
'ssl_cipher': 'AES256-SHA:CAMELLIA256-SHA',
}
result = read_option_files(option_files=os.path.join(
option_file_dir, 'my.cnf'))
self.assertEqual(exp, result)
exp = {
'password': '12345',
'port': 1001,
'unix_socket': '/var/run/mysqld/mysqld2.sock',
'ssl_ca': 'dummyCA',
'ssl_cert': 'dummyCert',
'ssl_key': 'dummyKey',
'ssl_cipher': 'AES256-SHA:CAMELLIA256-SHA',
'user': 'mysql',
}
result = read_option_files(option_files=os.path.join(
option_file_dir, 'my.cnf'), option_groups=['client', 'mysqld'])
self.assertEqual(exp, result)
option_file_dir = os.path.join('tests', 'data', 'option_files')
files = [
os.path.join(option_file_dir, 'include_files', '1.cnf'),
os.path.join(option_file_dir, 'include_files', '2.cnf'),
]
exp = {
'user': 'spam'
}
result = read_option_files(option_files=files,
option_groups=['client', 'mysql'])
self.assertEqual(exp, result)
self.assertRaises(ValueError, connect, option_files='dummy_file.cnf')

View File

@@ -1,496 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for PEP-249
Rewritten from scratch. Found Ian Bicking's test suite and shamelessly
stole few of his ideas. (Geert)
"""
import datetime
import time
import inspect
import tests
import mysql.connector as myconn
class PEP249Base(tests.MySQLConnectorTests):
def db_connect(self):
return myconn.connect(use_pure=True, **tests.get_mysql_config())
def get_connection_id(self, cursor):
cid = None
try:
cursor.execute("SELECT CONNECTION_ID()")
cid = cursor.fetchone()[0]
except myconn.errors.Error as err:
self.fail("Failed getting connection id; {0}".format(str(err)))
return cid
def setUp(self):
self.cnx = self.db_connect()
def tearDown(self):
self.cnx.close()
class PEP249ModuleTests(PEP249Base):
def setUp(self):
pass
def tearDown(self):
pass
def test_connect(self):
"""Interface exports the connect()-function"""
self.assertTrue(inspect.isfunction(myconn.connect),
"Module does not export the connect()-function")
cnx = myconn.connect(use_pure=True, **tests.get_mysql_config())
self.assertTrue(isinstance(cnx, myconn.connection.MySQLConnection),
"The connect()-method returns incorrect instance")
cnx = myconn.connect(**tests.get_mysql_config())
self.assertTrue(isinstance(cnx, myconn.connection.MySQLConnection),
"connect() not returning by default pure "
"MySQLConnection object")
if tests.MYSQL_CAPI:
cnx = myconn.connect(use_pure=False, **tests.get_mysql_config())
self.assertTrue(isinstance(cnx,
myconn.connection_cext.CMySQLConnection),
"The connect()-method returns incorrect instance")
def test_apilevel(self):
"""Interface sets the API level"""
self.assertTrue(hasattr(myconn, 'apilevel'),
"API level is not defined")
self.assertEqual('2.0', myconn.apilevel,
"API Level should be '2.0'")
def test_threadsafety(self):
"""Interface defines thread safety"""
self.assertTrue(myconn.threadsafety in (0, 1, 2, 3))
self.assertEqual(1, myconn.threadsafety)
def test_paramstyle(self):
"""Interface sets the parameter style"""
self.assertTrue(myconn.paramstyle in
('qmark', 'numeric', 'named', 'format', 'pyformat'),
"paramstyle was assigned an unsupported value")
self.assertEqual('pyformat', myconn.paramstyle,
"paramstyle should be 'pyformat'")
class PEP249ErrorsTests(PEP249Base):
def setUp(self):
pass
def tearDown(self):
pass
def test_Warning(self):
"""Interface exports the Warning-exception"""
self.assertTrue(issubclass(myconn.errors.Warning, Exception),
"Warning is not subclass of Exception")
def test_Error(self):
"""Interface exports the Error-exception"""
self.assertTrue(issubclass(myconn.errors.Error, Exception),
"Error is not subclass of Exception")
def test_InterfaceError(self):
"""Interface exports the InterfaceError-exception"""
self.assertTrue(issubclass(myconn.errors.InterfaceError,
myconn.errors.Error),
"InterfaceError is not subclass of errors.Error")
def test_DatabaseError(self):
"""Interface exports the DatabaseError-exception"""
self.assertTrue(issubclass(myconn.errors.DatabaseError,
myconn.errors.Error),
"DatabaseError is not subclass of errors.Error")
def test_DataError(self):
"""Interface exports the DataError-exception"""
self.assertTrue(issubclass(myconn.errors.DataError,
myconn.errors.DatabaseError),
"DataError is not subclass of errors.DatabaseError")
def test_OperationalError(self):
"""Interface exports the OperationalError-exception"""
self.assertTrue(
issubclass(myconn.errors.OperationalError,
myconn.errors.DatabaseError),
"OperationalError is not subclass of errors.DatabaseError")
def test_IntegrityError(self):
"""Interface exports the IntegrityError-exception"""
self.assertTrue(
issubclass(myconn.errors.IntegrityError,
myconn.errors.DatabaseError),
"IntegrityError is not subclass of errors.DatabaseError")
def test_InternalError(self):
"""Interface exports the InternalError-exception"""
self.assertTrue(issubclass(myconn.errors.InternalError,
myconn.errors.DatabaseError),
"InternalError is not subclass of errors.DatabaseError")
def test_ProgrammingError(self):
"""Interface exports the ProgrammingError-exception"""
self.assertTrue(
issubclass(myconn.errors.ProgrammingError,
myconn.errors.DatabaseError),
"ProgrammingError is not subclass of errors.DatabaseError")
def test_NotSupportedError(self):
"""Interface exports the NotSupportedError-exception"""
self.assertTrue(
issubclass(myconn.errors.NotSupportedError,
myconn.errors.DatabaseError),
"NotSupportedError is not subclass of errors.DatabaseError")
class PEP249ConnectionTests(PEP249Base):
def test_close(self):
"""Connection object has close()-method"""
self.assertTrue(hasattr(self.cnx, 'close'),
"Interface connection has no close()-method")
self.assertTrue(
inspect.ismethod(self.cnx.close),
"Interface connection defines connect, but is not a method")
def test_commit(self):
"""Connection object has commit()-method"""
self.assertTrue(hasattr(self.cnx, 'commit'),
"Interface connection has no commit()-method")
self.assertTrue(
inspect.ismethod(self.cnx.commit),
"Interface connection defines commit, but is not a method")
def test_rollback(self):
"""Connection object has rollback()-method"""
self.assertTrue(hasattr(self.cnx, 'rollback'),
"Interface connection has no rollback()-method")
self.assertTrue(
inspect.ismethod(self.cnx.rollback),
"Interface connection defines rollback, but is not a method")
def test_cursor(self):
"""Connection object has cursor()-method"""
self.assertTrue(hasattr(self.cnx, 'cursor'),
"Interface connection has no cursor()-method")
self.assertTrue(
inspect.ismethod(self.cnx.cursor),
"Interface connection defines cursor, but is not a method")
self.assertTrue(
isinstance(self.cnx.cursor(), myconn.cursor.MySQLCursor),
"Interface connection cursor()-method does not return a cursor")
class PEP249CursorTests(PEP249Base):
def setUp(self):
self.cnx = self.db_connect()
self.cur = self.cnx.cursor()
def test_description(self):
"""Cursor object has description-attribute"""
self.assertTrue(hasattr(self.cur, 'description'),
"Cursor object has no description-attribute")
self.assertEqual(None, self.cur.description,
"Cursor object's description should default ot None")
def test_rowcount(self):
"""Cursor object has rowcount-attribute"""
self.assertTrue(hasattr(self.cur, 'rowcount'),
"Cursor object has no rowcount-attribute")
self.assertEqual(-1, self.cur.rowcount,
"Cursor object's rowcount should default to -1")
def test_lastrowid(self):
"""Cursor object has lastrowid-attribute"""
self.assertTrue(hasattr(self.cur, 'lastrowid'),
"Cursor object has no lastrowid-attribute")
self.assertEqual(None, self.cur.lastrowid,
"Cursor object's lastrowid should default to None")
def test_callproc(self):
"""Cursor object has callproc()-method"""
self.assertTrue(hasattr(self.cur, 'callproc'),
"Cursor object has no callproc()-method")
self.assertTrue(inspect.ismethod(self.cur.callproc),
"Cursor object defines callproc, but is not a method")
def test_close(self):
"""Cursor object has close()-method"""
self.assertTrue(hasattr(self.cur, 'close'),
"Cursor object has no close()-method")
self.assertTrue(inspect.ismethod(self.cur.close),
"Cursor object defines close, but is not a method")
def test_execute(self):
"""Cursor object has execute()-method"""
self.assertTrue(hasattr(self.cur, 'execute'),
"Cursor object has no execute()-method")
self.assertTrue(inspect.ismethod(self.cur.execute),
"Cursor object defines execute, but is not a method")
def test_executemany(self):
"""Cursor object has executemany()-method"""
self.assertTrue(hasattr(self.cur, 'executemany'),
"Cursor object has no executemany()-method")
self.assertTrue(
inspect.ismethod(self.cur.executemany),
"Cursor object defines executemany, but is not a method")
def test_fetchone(self):
"""Cursor object has fetchone()-method"""
self.assertTrue(hasattr(self.cur, 'fetchone'),
"Cursor object has no fetchone()-method")
self.assertTrue(inspect.ismethod(self.cur.fetchone),
"Cursor object defines fetchone, but is not a method")
def test_fetchmany(self):
"""Cursor object has fetchmany()-method"""
self.assertTrue(hasattr(self.cur, 'execute'),
"Cursor object has no fetchmany()-method")
self.assertTrue(inspect.ismethod(self.cur.fetchmany),
"Cursor object defines fetchmany, but is not a method")
def test_fetchall(self):
"""Cursor object has fetchall()-method"""
self.assertTrue(hasattr(self.cur, 'fetchall'),
"Cursor object has no fetchall()-method")
self.assertTrue(inspect.ismethod(self.cur.fetchall),
"Cursor object defines fetchall, but is not a method")
def test_nextset(self):
"""Cursor object has nextset()-method"""
self.assertTrue(hasattr(self.cur, 'nextset'),
"Cursor object has no nextset()-method")
self.assertTrue(inspect.ismethod(self.cur.nextset),
"Cursor object defines nextset, but is not a method")
def test_arraysize(self):
"""Cursor object has arraysize-attribute"""
self.assertTrue(hasattr(self.cur, 'arraysize'),
"Cursor object has no arraysize-attribute")
self.assertEqual(1, self.cur.arraysize,
"Cursor object's arraysize should default to 1")
def test_setinputsizes(self):
"""Cursor object has setinputsizes()-method"""
self.assertTrue(hasattr(self.cur, 'setinputsizes'),
"Cursor object has no setinputsizes()-method")
self.assertTrue(inspect.ismethod(self.cur.setinputsizes),
"Cursor object's setinputsizes should default to 1")
def test_setoutputsize(self):
"""Cursor object has setoutputsize()-method"""
self.assertTrue(hasattr(self.cur, 'setoutputsize'),
"Cursor object has no setoutputsize()-method")
self.assertTrue(inspect.ismethod(self.cur.setoutputsize),
"Cursor object's setoutputsize should default to 1")
def _isolation_setup(self, drop, create):
cursor = self.cnx.cursor()
try:
cursor.execute(drop)
cursor.execute(create)
except myconn.errors.Error as err:
self.fail("Failed setting up test table; {0}".format(err))
cursor.close()
def _isolation_connection_equal(self, cnx1, cnx2):
cid1 = self.get_connection_id(cnx1)
cid2 = self.get_connection_id(cnx2)
return (cid1 == cid2)
def _isolation_cleanup(self, drop):
cursor = self.cnx.cursor()
try:
cursor.execute(drop)
except myconn.errors.Error as err:
self.fail("Failed cleaning up; {0}".format(err))
cursor.close()
def _isolation_test(self, cnx1, cnx2, engine='MyISAM'):
cur1 = cnx1.cursor()
cur2 = cnx2.cursor()
data = (1, 'myconnpy')
tbl = 'myconnpy_cursor_isolation'
stmt_create = (
"CREATE TABLE {table} "
"(col1 INT, col2 VARCHAR(30), PRIMARY KEY (col1)) "
"ENGINE={engine}"
).format(table=tbl, engine=engine)
stmt_drop = "DROP TABLE IF EXISTS {table}".format(table=tbl)
stmt_insert = (
"INSERT INTO {table} (col1,col2) "
"VALUES (%s,%s)"
).format(table=tbl)
stmt_select = "SELECT col1,col2 FROM {table}".format(table=tbl)
# Setup
cur1.execute("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
cur2.execute("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self._isolation_setup(stmt_drop, stmt_create)
conn_equal = self._isolation_connection_equal(cur1, cur2)
if cnx1 == cnx2 and not conn_equal:
self.fail("Cursors should have same connection ID")
elif cnx1 != cnx2 and conn_equal:
self.fail("Cursors should have different connection ID")
# Insert data
try:
cur1.execute(stmt_insert, data)
except myconn.errors.Error as err:
self.fail("Failed inserting test data; {0}".format(str(err)))
# Query for data
result = None
try:
cur2.execute(stmt_select)
result = cur2.fetchone()
except myconn.errors.InterfaceError:
pass
except myconn.errors.Error as err:
self.fail("Failed querying for test data; {0}".format(str(err)))
if conn_equal:
self.assertEqual(data, result)
elif not conn_equal and engine.lower() == 'innodb':
self.assertEqual(None, result)
# Clean up
self._isolation_cleanup(stmt_drop)
cur1.close()
cur2.close()
def test_isolation1(self):
"""Cursor isolation between 2 cursor on same connection"""
self._isolation_test(self.cnx, self.cnx, 'MyISAM')
def test_isolation2(self):
"""Cursor isolation with 2 cursors, different connections, trans."""
db2 = self.db_connect()
if tests.have_engine(db2, 'InnoDB'):
self._isolation_test(self.cnx, db2, 'InnoDB')
class PEP249TypeObjConstructorsTests(PEP249Base):
def test_Date(self):
"""Interface exports Date"""
exp = datetime.date(1977, 6, 14)
self.assertEqual(myconn.Date(1977, 6, 14), exp,
"Interface Date should return a datetime.date")
def test_Time(self):
"""Interface exports Time"""
exp = datetime.time(23, 56, 13)
self.assertEqual(myconn.Time(23, 56, 13), exp,
"Interface Time should return a datetime.time")
def test_Timestamp(self):
"""Interface exports Timestamp"""
adate = (1977, 6, 14, 21, 54, 23)
exp = datetime.datetime(*adate)
self.assertEqual(
myconn.Timestamp(*adate), exp,
"Interface Timestamp should return a datetime.datetime")
def test_DateFromTicks(self):
"""Interface exports DateFromTicks"""
ticks = 1
exp = datetime.date(*time.localtime(ticks)[:3])
self.assertEqual(
myconn.DateFromTicks(ticks), exp,
"Interface DateFromTicks should return a datetime.date")
def test_TimeFromTicks(self):
"""Interface exports TimeFromTicks"""
ticks = 1
exp = datetime.time(*time.localtime(ticks)[3:6])
self.assertEqual(
myconn.TimeFromTicks(ticks), exp,
"Interface TimeFromTicks should return a datetime.time")
def test_TimestampFromTicks(self):
"""Interface exports TimestampFromTicks"""
ticks = 1
exp = datetime.datetime(*time.localtime(ticks)[:6])
self.assertEqual(
myconn.TimestampFromTicks(ticks), exp,
"Interface TimestampFromTicks should return a datetime.datetime")
def test_Binary(self):
"""Interface exports Binary"""
exp = r'\u82b1'.encode('utf-8')
self.assertEqual(
myconn.Binary(r'\u82b1'.encode('utf-8')), exp,
"Interface Binary should return a str")
def test_STRING(self):
"""Interface exports STRING"""
self.assertTrue(hasattr(myconn, 'STRING'))
self.assertTrue(
isinstance(myconn.STRING, myconn.dbapi._DBAPITypeObject),
"Interface STRING should return a _DBAPITypeObject")
def test_BINARY(self):
"""Interface exports BINARY"""
self.assertTrue(hasattr(myconn, 'BINARY'))
self.assertTrue(
isinstance(myconn.BINARY, myconn.dbapi._DBAPITypeObject),
"Interface BINARY should return a _DBAPITypeObject")
def test_NUMBER(self):
"""Interface exports NUMBER"""
self.assertTrue(hasattr(myconn, 'NUMBER'))
self.assertTrue(
isinstance(myconn.NUMBER, myconn.dbapi._DBAPITypeObject),
"Interface NUMBER should return a _DBAPITypeObject")
def test_DATETIME(self):
"""Interface exports DATETIME"""
self.assertTrue(hasattr(myconn, 'DATETIME'))
self.assertTrue(
isinstance(myconn.DATETIME, myconn.dbapi._DBAPITypeObject),
"Interface DATETIME should return a _DBAPITypeObject")
def test_ROWID(self):
"""Interface exports ROWID"""
self.assertTrue(hasattr(myconn, 'ROWID'))
self.assertTrue(
isinstance(myconn.ROWID, myconn.dbapi._DBAPITypeObject),
"Interface ROWID should return a _DBAPITypeObject")

View File

@@ -1,361 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.pooling
"""
import uuid
try:
from Queue import Queue
except ImportError:
# Python 3
from queue import Queue
import tests
import mysql.connector
from mysql.connector import errors
from mysql.connector.connection import MySQLConnection
from mysql.connector import pooling
class PoolingTests(tests.MySQLConnectorTests):
def tearDown(self):
mysql.connector._CONNECTION_POOLS = {}
def test_generate_pool_name(self):
self.assertRaises(errors.PoolError, pooling.generate_pool_name)
config = {'host': 'ham', 'database': 'spam'}
self.assertEqual('ham_spam',
pooling.generate_pool_name(**config))
config = {'database': 'spam', 'port': 3377, 'host': 'example.com'}
self.assertEqual('example.com_3377_spam',
pooling.generate_pool_name(**config))
config = {
'user': 'ham', 'database': 'spam',
'port': 3377, 'host': 'example.com'}
self.assertEqual('example.com_3377_ham_spam',
pooling.generate_pool_name(**config))
class PooledMySQLConnectionTests(tests.MySQLConnectorTests):
def tearDown(self):
mysql.connector._CONNECTION_POOLS = {}
def test___init__(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
self.assertRaises(TypeError, pooling.PooledMySQLConnection)
cnx = MySQLConnection(**dbconfig)
pcnx = pooling.PooledMySQLConnection(cnxpool, cnx)
self.assertEqual(cnxpool, pcnx._cnx_pool)
self.assertEqual(cnx, pcnx._cnx)
self.assertRaises(AttributeError, pooling.PooledMySQLConnection,
None, None)
self.assertRaises(AttributeError, pooling.PooledMySQLConnection,
cnxpool, None)
def test___getattr__(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_size=1, pool_name='test')
cnx = MySQLConnection(**dbconfig)
pcnx = pooling.PooledMySQLConnection(cnxpool, cnx)
exp_attrs = {
'_connection_timeout': dbconfig['connection_timeout'],
'_database': dbconfig['database'],
'_host': dbconfig['host'],
'_password': dbconfig['password'],
'_port': dbconfig['port'],
'_unix_socket': dbconfig['unix_socket']
}
for attr, value in exp_attrs.items():
self.assertEqual(
value,
getattr(pcnx, attr),
"Attribute {0} of reference connection not correct".format(
attr))
self.assertEqual(pcnx.connect, cnx.connect)
def test_close(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
cnxpool._original_cnx = None
def dummy_add_connection(self, cnx=None):
self._original_cnx = cnx
cnxpool.add_connection = dummy_add_connection.__get__(
cnxpool, pooling.MySQLConnectionPool)
pcnx = pooling.PooledMySQLConnection(cnxpool,
MySQLConnection(**dbconfig))
cnx = pcnx._cnx
pcnx.close()
self.assertEqual(cnx, cnxpool._original_cnx)
def test_config(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
cnx = cnxpool.get_connection()
self.assertRaises(errors.PoolError, cnx.config, user='spam')
class MySQLConnectionPoolTests(tests.MySQLConnectorTests):
def tearDown(self):
mysql.connector._CONNECTION_POOLS = {}
def test___init__(self):
dbconfig = tests.get_mysql_config()
self.assertRaises(errors.PoolError, pooling.MySQLConnectionPool)
self.assertRaises(AttributeError, pooling.MySQLConnectionPool,
pool_name='test',
pool_size=-1)
self.assertRaises(AttributeError, pooling.MySQLConnectionPool,
pool_name='test',
pool_size=0)
self.assertRaises(AttributeError, pooling.MySQLConnectionPool,
pool_name='test',
pool_size=(pooling.CNX_POOL_MAXSIZE + 1))
cnxpool = pooling.MySQLConnectionPool(pool_name='test')
self.assertEqual(5, cnxpool._pool_size)
self.assertEqual('test', cnxpool._pool_name)
self.assertEqual({}, cnxpool._cnx_config)
self.assertTrue(isinstance(cnxpool._cnx_queue, Queue))
self.assertTrue(isinstance(cnxpool._config_version, uuid.UUID))
self.assertTrue(True, cnxpool._reset_session)
cnxpool = pooling.MySQLConnectionPool(pool_size=10, pool_name='test')
self.assertEqual(10, cnxpool._pool_size)
cnxpool = pooling.MySQLConnectionPool(pool_size=10, **dbconfig)
self.assertEqual(dbconfig, cnxpool._cnx_config,
"Connection configuration not saved correctly")
self.assertEqual(10, cnxpool._cnx_queue.qsize())
self.assertTrue(isinstance(cnxpool._config_version, uuid.UUID))
cnxpool = pooling.MySQLConnectionPool(pool_size=1, pool_name='test',
pool_reset_session=False)
self.assertFalse(cnxpool._reset_session)
def test_pool_name(self):
"""Test MySQLConnectionPool.pool_name property"""
pool_name = 'ham'
cnxpool = pooling.MySQLConnectionPool(pool_name=pool_name)
self.assertEqual(pool_name, cnxpool.pool_name)
def test_reset_session(self):
"""Test MySQLConnectionPool.reset_session property"""
cnxpool = pooling.MySQLConnectionPool(pool_name='test',
pool_reset_session=False)
self.assertFalse(cnxpool.reset_session)
cnxpool._reset_session = True
self.assertTrue(cnxpool.reset_session)
def test_pool_size(self):
"""Test MySQLConnectionPool.pool_size property"""
pool_size = 4
cnxpool = pooling.MySQLConnectionPool(pool_name='test',
pool_size=pool_size)
self.assertEqual(pool_size, cnxpool.pool_size)
def test_reset_session(self):
"""Test MySQLConnectionPool.reset_session property"""
cnxpool = pooling.MySQLConnectionPool(pool_name='test',
pool_reset_session=False)
self.assertFalse(cnxpool.reset_session)
cnxpool._reset_session = True
self.assertTrue(cnxpool.reset_session)
def test__set_pool_size(self):
cnxpool = pooling.MySQLConnectionPool(pool_name='test')
self.assertRaises(AttributeError, cnxpool._set_pool_size, -1)
self.assertRaises(AttributeError, cnxpool._set_pool_size, 0)
self.assertRaises(AttributeError, cnxpool._set_pool_size,
pooling.CNX_POOL_MAXSIZE + 1)
cnxpool._set_pool_size(pooling.CNX_POOL_MAXSIZE - 1)
self.assertEqual(pooling.CNX_POOL_MAXSIZE - 1, cnxpool._pool_size)
def test__set_pool_name(self):
cnxpool = pooling.MySQLConnectionPool(pool_name='test')
self.assertRaises(AttributeError, cnxpool._set_pool_name, 'pool name')
self.assertRaises(AttributeError, cnxpool._set_pool_name, 'pool%%name')
self.assertRaises(AttributeError, cnxpool._set_pool_name,
'long_pool_name' * pooling.CNX_POOL_MAXNAMESIZE)
def test_add_connection(self):
cnxpool = pooling.MySQLConnectionPool(pool_name='test')
self.assertRaises(errors.PoolError, cnxpool.add_connection)
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_size=2, pool_name='test')
cnxpool.set_config(**dbconfig)
cnxpool.add_connection()
pcnx = pooling.PooledMySQLConnection(
cnxpool,
cnxpool._cnx_queue.get(block=False))
self.assertTrue(isinstance(pcnx._cnx, MySQLConnection))
self.assertEqual(cnxpool, pcnx._cnx_pool)
self.assertEqual(cnxpool._config_version,
pcnx._cnx._pool_config_version)
cnx = pcnx._cnx
pcnx.close()
# We should get the same connectoin back
self.assertEqual(cnx, cnxpool._cnx_queue.get(block=False))
cnxpool.add_connection(cnx)
# reach max connections
cnxpool.add_connection()
self.assertRaises(errors.PoolError, cnxpool.add_connection)
# fail connecting
cnxpool._remove_connections()
cnxpool._cnx_config['port'] = 9999999
cnxpool._cnx_config['unix_socket'] = '/ham/spam/foobar.socket'
self.assertRaises(errors.InterfaceError, cnxpool.add_connection)
self.assertRaises(errors.PoolError, cnxpool.add_connection, cnx=str)
def test_set_config(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_name='test')
# No configuration changes
config_version = cnxpool._config_version
cnxpool.set_config()
self.assertEqual(config_version, cnxpool._config_version)
self.assertEqual({}, cnxpool._cnx_config)
# Valid configuration changes
config_version = cnxpool._config_version
cnxpool.set_config(**dbconfig)
self.assertEqual(dbconfig, cnxpool._cnx_config)
self.assertNotEqual(config_version, cnxpool._config_version)
# Invalid configuration changes
config_version = cnxpool._config_version
wrong_dbconfig = dbconfig.copy()
wrong_dbconfig['spam'] = 'ham'
self.assertRaises(errors.PoolError, cnxpool.set_config,
**wrong_dbconfig)
self.assertEqual(dbconfig, cnxpool._cnx_config)
self.assertEqual(config_version, cnxpool._config_version)
def test_get_connection(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(pool_size=2, pool_name='test')
self.assertRaises(errors.PoolError, cnxpool.get_connection)
cnxpool = pooling.MySQLConnectionPool(pool_size=1, **dbconfig)
# Get connection from pool
pcnx = cnxpool.get_connection()
self.assertTrue(isinstance(pcnx, pooling.PooledMySQLConnection))
self.assertRaises(errors.PoolError, cnxpool.get_connection)
self.assertEqual(pcnx._cnx._pool_config_version,
cnxpool._config_version)
prev_config_version = pcnx._pool_config_version
prev_thread_id = pcnx.connection_id
pcnx.close()
# Change configuration
config_version = cnxpool._config_version
cnxpool.set_config(autocommit=True)
self.assertNotEqual(config_version, cnxpool._config_version)
pcnx = cnxpool.get_connection()
self.assertNotEqual(
pcnx._cnx._pool_config_version, prev_config_version)
self.assertNotEqual(prev_thread_id, pcnx.connection_id)
self.assertEqual(1, pcnx.autocommit)
pcnx.close()
def test__remove_connections(self):
dbconfig = tests.get_mysql_config()
cnxpool = pooling.MySQLConnectionPool(
pool_size=2, pool_name='test', **dbconfig)
pcnx = cnxpool.get_connection()
self.assertEqual(1, cnxpool._remove_connections())
pcnx.close()
self.assertEqual(1, cnxpool._remove_connections())
self.assertEqual(0, cnxpool._remove_connections())
self.assertRaises(errors.PoolError, cnxpool.get_connection)
class ModuleConnectorPoolingTests(tests.MySQLConnectorTests):
"""Testing MySQL Connector module pooling functionality"""
def tearDown(self):
mysql.connector._CONNECTION_POOLS = {}
def test__connection_pools(self):
self.assertEqual(mysql.connector._CONNECTION_POOLS, {})
def test__get_pooled_connection(self):
dbconfig = tests.get_mysql_config()
mysql.connector._CONNECTION_POOLS.update({'spam': 'ham'})
self.assertRaises(errors.InterfaceError,
mysql.connector.connect, pool_name='spam')
mysql.connector._CONNECTION_POOLS = {}
mysql.connector.connect(pool_name='ham', **dbconfig)
self.assertTrue('ham' in mysql.connector._CONNECTION_POOLS)
cnxpool = mysql.connector._CONNECTION_POOLS['ham']
self.assertTrue(isinstance(cnxpool,
pooling.MySQLConnectionPool))
self.assertEqual('ham', cnxpool.pool_name)
mysql.connector.connect(pool_size=5, **dbconfig)
pool_name = pooling.generate_pool_name(**dbconfig)
self.assertTrue(pool_name in mysql.connector._CONNECTION_POOLS)
def test_connect(self):
dbconfig = tests.get_mysql_config()
cnx = mysql.connector.connect(pool_size=1, pool_name='ham', **dbconfig)
exp = cnx.connection_id
cnx.close()
self.assertEqual(
exp,
mysql.connector._get_pooled_connection(
pool_name='ham').connection_id
)

View File

@@ -1,574 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2009, 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unittests for mysql.connector.protocol
"""
import struct
import datetime
import decimal
import tests
from mysql.connector import (protocol, errors)
from mysql.connector.constants import (ClientFlag, FieldType, FieldFlag)
OK_PACKET = bytearray(b'\x07\x00\x00\x01\x00\x01\x00\x00\x00\x01\x00')
OK_PACKET_RESULT = {
'insert_id': 0,
'affected_rows': 1,
'field_count': 0,
'warning_count': 1,
'status_flag': 0
}
ERR_PACKET = bytearray(
b'\x47\x00\x00\x02\xff\x15\x04\x23\x32\x38\x30\x30\x30'
b'\x41\x63\x63\x65\x73\x73\x20\x64\x65\x6e\x69\x65\x64'
b'\x20\x66\x6f\x72\x20\x75\x73\x65\x72\x20\x27\x68\x61'
b'\x6d\x27\x40\x27\x6c\x6f\x63\x61\x6c\x68\x6f\x73\x74'
b'\x27\x20\x28\x75\x73\x69\x6e\x67\x20\x70\x61\x73\x73'
b'\x77\x6f\x72\x64\x3a\x20\x59\x45\x53\x29'
)
EOF_PACKET = bytearray(b'\x01\x00\x00\x00\xfe\x00\x00\x00\x00')
EOF_PACKET_RESULT = {'status_flag': 0, 'warning_count': 0}
SEED = bytearray(
b'\x3b\x55\x78\x7d\x2c\x5f\x7c\x72\x49\x52'
b'\x3f\x28\x47\x6f\x77\x28\x5f\x28\x46\x69'
)
class MySQLProtocolTests(tests.MySQLConnectorTests):
def setUp(self):
self._protocol = protocol.MySQLProtocol()
def test_make_auth(self):
"""Make a MySQL authentication packet"""
exp = {
'allset': bytearray(
b'\x8d\xa2\x03\x00\x00\x00\x00\x40'
b'\x21\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x68\x61\x6d\x00\x14\x3a\x07\x66\xba\xba\x01\xce'
b'\xbe\x55\xe6\x29\x88\xaa\xae\xdb\x00\xb3\x4d\x91'
b'\x5b\x74\x65\x73\x74\x00'),
'nopass': bytearray(
b'\x8d\xa2\x03\x00\x00\x00\x00\x40'
b'\x21\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x68\x61\x6d\x00\x00\x74\x65\x73\x74\x00'),
'nouser': bytearray(
b'\x8d\xa2\x03\x00\x00\x00\x00\x40'
b'\x21\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x14\x3a\x07\x66\xba\xba\x01\xce'
b'\xbe\x55\xe6\x29\x88\xaa\xae\xdb\x00\xb3\x4d\x91'
b'\x5b\x74\x65\x73\x74\x00'),
'nodb': bytearray(
b'\x8d\xa2\x03\x00\x00\x00\x00\x40'
b'\x21\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x68\x61\x6d\x00\x14\x3a\x07\x66\xba\xba\x01\xce'
b'\xbe\x55\xe6\x29\x88\xaa\xae\xdb\x00\xb3\x4d\x91'
b'\x5b\x00'),
}
flags = ClientFlag.get_default()
kwargs = {
'handshake': None,
'username': 'ham',
'password': 'spam',
'database': 'test',
'charset': 33,
'client_flags': flags
}
self.assertRaises(errors.ProgrammingError,
self._protocol.make_auth, **kwargs)
kwargs['handshake'] = {'auth_data': SEED}
self.assertRaises(errors.ProgrammingError,
self._protocol.make_auth, **kwargs)
kwargs['handshake'] = {
'auth_data': SEED,
'auth_plugin': 'mysql_native_password'
}
res = self._protocol.make_auth(**kwargs)
self.assertEqual(exp['allset'], res)
kwargs['password'] = None
res = self._protocol.make_auth(**kwargs)
self.assertEqual(exp['nopass'], res)
kwargs['password'] = 'spam'
kwargs['database'] = None
res = self._protocol.make_auth(**kwargs)
self.assertEqual(exp['nodb'], res)
kwargs['username'] = None
kwargs['database'] = 'test'
res = self._protocol.make_auth(**kwargs)
self.assertEqual(exp['nouser'], res)
def test_make_auth_ssl(self):
"""Make a SSL authentication packet"""
cases = [
({},
b'\x00\x00\x00\x00\x00\x00\x00\x40\x21\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00'),
({'charset': 8},
b'\x00\x00\x00\x00\x00\x00\x00\x40\x08\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00'),
({'client_flags': 240141},
b'\x0d\xaa\x03\x00\x00\x00\x00\x40\x21\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00'),
({'charset': 8, 'client_flags': 240141,
'max_allowed_packet': 2147483648},
b'\x0d\xaa\x03\x00\x00\x00\x00\x80\x08\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x00\x00'),
]
for kwargs, exp in cases:
self.assertEqual(exp, self._protocol.make_auth_ssl(**kwargs))
def test_make_command(self):
"""Make a generic MySQL command packet"""
exp = bytearray(b'\x01\x68\x61\x6d')
arg = 'ham'.encode('utf-8')
res = self._protocol.make_command(1, arg)
self.assertEqual(exp, res)
res = self._protocol.make_command(1, argument=arg)
self.assertEqual(exp, res)
exp = b'\x03'
res = self._protocol.make_command(3)
self.assertEqual(exp, res)
def test_make_changeuser(self):
"""Make a change user MySQL packet"""
exp = {
'allset': bytearray(
b'\x11\x68\x61\x6d\x00\x14\x3a\x07'
b'\x66\xba\xba\x01\xce\xbe\x55\xe6\x29\x88\xaa\xae'
b'\xdb\x00\xb3\x4d\x91\x5b\x74\x65\x73\x74\x00\x08'
b'\x00'),
'nopass': bytearray(
b'\x11\x68\x61\x6d\x00\x00\x74\x65'
b'\x73\x74\x00\x08\x00'),
}
kwargs = {
'handshake': None,
'username': 'ham',
'password': 'spam',
'database': 'test',
'charset': 8,
'client_flags': ClientFlag.get_default()
}
self.assertRaises(errors.ProgrammingError,
self._protocol.make_change_user, **kwargs)
kwargs['handshake'] = {'auth_data': SEED}
self.assertRaises(errors.ProgrammingError,
self._protocol.make_change_user, **kwargs)
kwargs['handshake'] = {
'auth_data': SEED,
'auth_plugin': 'mysql_native_password'
}
res = self._protocol.make_change_user(**kwargs)
self.assertEqual(exp['allset'], res)
kwargs['password'] = None
res = self._protocol.make_change_user(**kwargs)
self.assertEqual(exp['nopass'], res)
def test_parse_handshake(self):
"""Parse handshake-packet sent by MySQL"""
handshake = bytearray(
b'\x47\x00\x00\x00\x0a\x35\x2e\x30\x2e\x33\x30\x2d'
b'\x65\x6e\x74\x65\x72\x70\x72\x69\x73\x65\x2d\x67'
b'\x70\x6c\x2d\x6c\x6f\x67\x00\x09\x01\x00\x00\x68'
b'\x34\x69\x36\x6f\x50\x21\x4f\x00\x2c\xa2\x08\x02'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00\x00\x4c\x6e\x67\x39\x26\x50\x44\x40\x57\x72'
b'\x59\x48\x00'
)
exp = {
'protocol': 10,
'server_version_original': '5.0.30-enterprise-gpl-log',
'charset': 8,
'server_threadid': 265,
'capabilities': 41516,
'server_status': 2,
'auth_data': b'h4i6oP!OLng9&PD@WrYH',
'auth_plugin': 'mysql_native_password',
}
res = self._protocol.parse_handshake(handshake)
self.assertEqual(exp, res)
# Test when end byte \x00 is not present for server 5.5.8
handshake = handshake[:-1]
res = self._protocol.parse_handshake(handshake)
self.assertEqual(exp, res)
def test_parse_ok(self):
"""Parse OK-packet sent by MySQL"""
res = self._protocol.parse_ok(OK_PACKET)
self.assertEqual(OK_PACKET_RESULT, res)
okpkt = OK_PACKET + b'\x04spam'
exp = OK_PACKET_RESULT.copy()
exp['info_msg'] = 'spam'
res = self._protocol.parse_ok(okpkt)
self.assertEqual(exp, res)
def test_parse_column_count(self):
"""Parse the number of columns"""
packet = bytearray(b'\x01\x00\x00\x01\x03')
res = self._protocol.parse_column_count(packet)
self.assertEqual(3, res)
packet = bytearray(b'\x01\x00')
self.assertRaises(errors.InterfaceError,
self._protocol.parse_column_count, packet)
def test_parse_column(self):
"""Parse field-packet sent by MySQL"""
column_packet = bytearray(
b'\x1a\x00\x00\x02\x03\x64\x65\x66\x00\x00\x00\x04'
b'\x53\x70\x61\x6d\x00\x0c\x21\x00\x09\x00\x00\x00'
b'\xfd\x01\x00\x1f\x00\x00')
exp = ('Spam', 253, None, None, None, None, 0, 1)
res = self._protocol.parse_column(column_packet)
self.assertEqual(exp, res)
def test_parse_eof(self):
"""Parse EOF-packet sent by MySQL"""
res = self._protocol.parse_eof(EOF_PACKET)
self.assertEqual(EOF_PACKET_RESULT, res)
def test_read_text_result(self):
# Tested by MySQLConnectionTests.test_get_rows() and .test_get_row()
pass
def test_parse_binary_prepare_ok(self):
"""Parse Prepare OK packet"""
cases = [
# SELECT CONCAT(?, ?) AS c1
(bytearray(b'\x0c\x00\x00\x01'
b'\x00\x01\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00'),
{'num_params': 2,
'statement_id': 1,
'warning_count': 0,
'num_columns': 1
}
),
# DO 1
(bytearray(b'\x0c\x00\x00\x01'
b'\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'),
{'num_params': 0,
'statement_id': 1,
'warning_count': 0,
'num_columns': 0
}
),
]
for packet, exp in cases:
self.assertEqual(exp,
self._protocol.parse_binary_prepare_ok(packet))
def test__parse_binary_integer(self):
"""Parse an integer from a binary packet"""
# Case = Expected value; pack format; field type; field flag
cases = [
(-128, 'b', FieldType.TINY, 0),
(-32768, 'h', FieldType.SHORT, 0),
(-2147483648, 'i', FieldType.LONG, 0),
(-9999999999, 'q', FieldType.LONGLONG, 0),
(255, 'B', FieldType.TINY, FieldFlag.UNSIGNED),
(65535, 'H', FieldType.SHORT, FieldFlag.UNSIGNED),
(4294967295, 'I', FieldType.LONG, FieldFlag.UNSIGNED),
(9999999999, 'Q', FieldType.LONGLONG, FieldFlag.UNSIGNED),
]
field_info = [None] * 8
field_info[0] = 'c1'
for exp, fmt, field_type, flag in cases:
field_info[1] = field_type
field_info[7] = flag
data = struct.pack(fmt, exp) + b'\x00\x00'
res = self._protocol._parse_binary_integer(data, field_info)
self.assertEqual((b'\x00\x00', exp), res,
"Failed parsing binary integer '{0}'".format(exp))
def test__parse_binary_float(self):
"""Parse a float/double from a binary packet"""
# Case = Expected value; data; field type
cases = [
(-3.14159, bytearray(b'\x6e\x86\x1b\xf0\xf9\x21\x09\xc0'),
FieldType.DOUBLE),
(3.14159, bytearray(b'\x6e\x86\x1b\xf0\xf9\x21\x09\x40'),
FieldType.DOUBLE),
(-3.14, bytearray(b'\xc3\xf5\x48\xc0'), FieldType.FLOAT),
(3.14, bytearray(b'\xc3\xf5\x48\x40'), FieldType.FLOAT),
]
field_info = [None] * 8
field_info[0] = 'c1'
for exp, data, field_type in cases:
field_info[1] = field_type
res = self._protocol._parse_binary_float(data + b'\x00\x00',
field_info)
self.assertEqual(bytearray(b'\x00\x00'), res[0],
"Failed parsing binary float '{0}'".format(exp))
self.assertAlmostEqual(
exp, res[1], places=5,
msg="Failed parsing binary float '{0}'".format(exp))
def test__parse_binary_timestamp(self):
"""Parse a timestamp from a binary packet"""
# Case = Expected value; data
cases = [
(datetime.date(1977, 6, 14), bytearray(b'\x04\xb9\x07\x06\x0e')),
(datetime.datetime(1977, 6, 14, 21, 33, 14),
bytearray(b'\x07\xb9\x07\x06\x0e\x15\x21\x0e')),
(datetime.datetime(1977, 6, 14, 21, 33, 14, 345),
bytearray(b'\x0b\xb9\x07\x06\x0e\x15\x21\x0e\x59\x01\x00\x00'))
]
for exp, data in cases:
res = self._protocol._parse_binary_timestamp(data + b'\x00\x00',
None)
self.assertEqual((b'\x00\x00', exp), res,
"Failed parsing timestamp '{0}'".format(exp))
def test__parse_binary_time(self):
"""Parse a time value from a binary packet"""
cases = [
(datetime.timedelta(0, 44130),
bytearray(b'\x08\x00\x00\x00\x00\x00\x0c\x0f\x1e')),
(datetime.timedelta(14, 15330),
bytearray(b'\x08\x00\x0e\x00\x00\x00\x04\x0f\x1e')),
(datetime.timedelta(-14, 15330),
bytearray(b'\x08\x01\x0e\x00\x00\x00\x04\x0f\x1e')),
(datetime.timedelta(10, 58530, 230000),
bytearray(b'\x0c\x00\x0a\x00\x00\x00'
b'\x10\x0f\x1e\x70\x82\x03\x00')),
]
for exp, data in cases:
res = self._protocol._parse_binary_time(data + b'\x00\x00', None)
self.assertEqual((bytearray(b'\x00\x00'), exp), res,
"Failed parsing time '{0}'".format(exp))
def test__parse_binary_values(self):
"""Parse values from a binary result packet"""
# The packet in this test is result of the following query:
# SELECT 'abc' AS aStr,"
# "3.14 AS aFloat,"
# "-3.14159 AS aDouble, "
# "MAKEDATE(2003, 31) AS aDate, "
# "TIMESTAMP('1977-06-14', '21:33:14') AS aDateTime, "
# "MAKETIME(256,15,30.23) AS aTime, "
# "NULL AS aNull"
#
fields = [('aStr', 253, None, None, None, None, 0, 1),
('aFloat', 246, None, None, None, None, 0, 129),
('aDouble', 246, None, None, None, None, 0, 129),
('aDate', 10, None, None, None, None, 1, 128),
('aDateTime', 12, None, None, None, None, 1, 128),
('aTime', 11, None, None, None, None, 1, 128),
('aNull', 6, None, None, None, None, 1, 128)]
packet = bytearray(b'\x00\x01\x03\x61\x62\x63\x04\x33\x2e\x31\x34\x08'
b'\x2d\x33\x2e\x31\x34\x31\x35\x39\x04\xd3\x07'
b'\x01\x1f\x07\xb9\x07\x06\x0e\x15\x21\x0e\x0c'
b'\x00\x0a\x00\x00\x00\x10\x0f\x1e\x70\x82\x03\x00')
# float/double are returned as DECIMAL by MySQL
exp = (bytearray(b'abc'),
bytearray(b'3.14'),
bytearray(b'-3.14159'),
datetime.date(2003, 1, 31),
datetime.datetime(1977, 6, 14, 21, 33, 14),
datetime.timedelta(10, 58530, 230000),
None)
res = self._protocol._parse_binary_values(fields, packet)
self.assertEqual(exp, res)
def test_read_binary_result(self):
"""Read MySQL binary protocol result"""
def test__prepare_binary_integer(self):
"""Prepare an integer for the MySQL binary protocol"""
# Case = Data; expected value
cases = [
(-128, (struct.pack('b', -128), FieldType.TINY, 0)),
(-32768, (struct.pack('h', -32768), FieldType.SHORT, 0)),
(-2147483648,
(struct.pack('i', -2147483648), FieldType.LONG, 0)),
(-9999999999,
(struct.pack('q', -9999999999), FieldType.LONGLONG, 0)),
(255, (struct.pack('B', 255), FieldType.TINY, 128)),
(65535, (struct.pack('H', 65535), FieldType.SHORT, 128)),
(4294967295,
(struct.pack('I', 4294967295), FieldType.LONG, 128)),
(9999999999,
(struct.pack('Q', 9999999999), FieldType.LONGLONG, 128)),
]
for data, exp in cases:
res = self._protocol._prepare_binary_integer(data)
self.assertEqual(exp, res,
"Failed preparing value '{0}'".format(data))
def test__prepare_binary_timestamp(self):
"""Prepare a timestamp object for the MySQL binary protocol"""
cases = [
(datetime.date(1977, 6, 14),
(bytearray(b'\x04\xb9\x07\x06\x0e'), 10)),
(datetime.datetime(1977, 6, 14),
(bytearray(b'\x07\xb9\x07\x06\x0e\x00\x00\x00'), 12)),
(datetime.datetime(1977, 6, 14, 21, 33, 14),
(bytearray(b'\x07\xb9\x07\x06\x0e\x15\x21\x0e'), 12)),
(datetime.datetime(1977, 6, 14, 21, 33, 14, 345),
(bytearray(b'\x0b\xb9\x07\x06\x0e\x15'
b'\x21\x0e\x59\x01\x00\x00'), 12)),
]
for data, exp in cases:
self.assertEqual(exp,
self._protocol._prepare_binary_timestamp(data),
"Failed preparing value '{0}'".format(data))
# Raise an error
self.assertRaises(ValueError,
self._protocol._prepare_binary_timestamp, 'spam')
def test__prepare_binary_time(self):
"""Prepare a time object for the MySQL binary protocol"""
cases = [
(datetime.timedelta(hours=123, minutes=45, seconds=16),
(bytearray(b'\x08\x00\x05\x00\x00\x00\x03\x2d\x10'), 11)),
(datetime.timedelta(hours=-123, minutes=45, seconds=16),
(bytearray(b'\x08\x01\x06\x00\x00\x00\x15\x2d\x10'), 11)),
(datetime.timedelta(hours=123, minutes=45, seconds=16,
microseconds=345),
(bytearray(b'\x0c\x00\x05\x00\x00\x00\x03'
b'\x2d\x10\x59\x01\x00\x00'), 11)),
(datetime.timedelta(days=123, minutes=45, seconds=16),
(bytearray(b'\x08\x00\x7b\x00\x00\x00\x00\x2d\x10'), 11)),
(datetime.time(14, 53, 36),
(bytearray(b'\x08\x00\x00\x00\x00\x00\x0e\x35\x24'), 11)),
(datetime.time(14, 53, 36, 345),
(bytearray(b'\x0c\x00\x00\x00\x00\x00\x0e'
b'\x35\x24\x59\x01\x00\x00'), 11))
]
for data, exp in cases:
self.assertEqual(exp,
self._protocol._prepare_binary_time(data),
"Failed preparing value '{0}'".format(data))
# Raise an error
self.assertRaises(ValueError,
self._protocol._prepare_binary_time, 'spam')
def test_make_stmt_execute(self):
"""Make a MySQL packet with the STMT_EXECUTE command"""
statement_id = 1
self.assertRaises(errors.InterfaceError,
self._protocol.make_stmt_execute, statement_id,
('ham', 'spam'), (1, 2, 3))
data = ('ham', 'spam')
exp = bytearray(
b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00\x01\x0f'
b'\x00\x0f\x00\x03\x68\x61\x6d\x04\x73\x70\x61\x6d'
)
res = self._protocol.make_stmt_execute(statement_id, data, (1, 2))
self.assertEqual(exp, res)
# Testing types
cases = [
('ham',
bytearray(b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x01\x0f\x00\x03\x68\x61\x6d')),
(decimal.Decimal('3.14'),
bytearray(b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x01\x00\x00\x04\x33\x2e\x31\x34')),
(255,
bytearray(b'\x01\x00\x00\x00\x80\x01\x00'
b'\x00\x00\x00\x01\x01\x80\xff')),
(-128,
bytearray(b'\x01\x00\x00\x00\x00\x01\x00'
b'\x00\x00\x00\x01\x01\x00\x80')),
(datetime.datetime(1977, 6, 14, 21, 20, 30),
bytearray(b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00'
b'\x01\x0c\x00\x07\xb9\x07\x06\x0e\x15\x14\x1e')),
(datetime.time(14, 53, 36, 345),
bytearray(b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00\x01\x0b\x00'
b'\x0c\x00\x00\x00\x00\x00\x0e\x35\x24\x59\x01\x00\x00')),
(3.14,
bytearray(b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00\x01\x05\x00'
b'\x1f\x85\xeb\x51\xb8\x1e\x09\x40')),
]
for data, exp in cases:
res = self._protocol.make_stmt_execute(statement_id, (data,), (1,))
self.assertEqual(
exp, res, "Failed preparing statement with '{0}'".format(data))
# Testing null bitmap
data = (None, None)
exp = bytearray(b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x03\x01\x06'
b'\x00\x06\x00')
res = self._protocol.make_stmt_execute(statement_id, data, (1, 2))
self.assertEqual(exp, res)
data = (None, 'Ham')
exp = bytearray(
b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x01\x01\x06\x00\x0f\x00'
b'\x03\x48\x61\x6d'
)
res = self._protocol.make_stmt_execute(statement_id, data, (1, 2))
self.assertEqual(exp, res)
data = ('a',) * 11
exp = bytearray(
b'\x01\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01'
b'\x0f\x00\x0f\x00\x0f\x00\x0f\x00\x0f\x00\x0f\x00\x0f\x00'
b'\x0f\x00\x0f\x00\x0f\x00\x0f\x00\x01\x61\x01\x61\x01\x61'
b'\x01\x61\x01\x61\x01\x61\x01\x61\x01\x61\x01\x61\x01\x61'
b'\x01\x61'
)
res = self._protocol.make_stmt_execute(statement_id, data, (1,) * 11)
self.assertEqual(exp, res)
# Raise an error passing an unsupported object as parameter value
class UnSupportedObject(object):
pass
data = (UnSupportedObject(), UnSupportedObject())
self.assertRaises(errors.ProgrammingError,
self._protocol.make_stmt_execute,
statement_id, data, (1, 2))

View File

@@ -1,141 +0,0 @@
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Unit tests for the setup script of Connector/Python
"""
import sys
import tests
import imp
import setupinfo
class VersionTests(tests.MySQLConnectorTests):
"""Testing the version of Connector/Python"""
def test_version(self):
"""Test validity of version"""
vs = setupinfo.VERSION
self.assertTrue(all(
[isinstance(vs[0], int),
isinstance(vs[1], int),
isinstance(vs[2], int),
isinstance(vs[3], str),
isinstance(vs[4], int)]))
def test___version__(self):
"""Test module __version__ and __version_info__"""
import mysql.connector
self.assertTrue(hasattr(mysql.connector, '__version__'))
self.assertTrue(hasattr(mysql.connector, '__version_info__'))
self.assertTrue(isinstance(mysql.connector.__version__, str))
self.assertTrue(isinstance(mysql.connector.__version_info__, tuple))
self.assertEqual(setupinfo.VERSION_TEXT, mysql.connector.__version__)
self.assertEqual(setupinfo.VERSION, mysql.connector.__version_info__)
class SetupInfoTests(tests.MySQLConnectorTests):
"""Testing meta setup information
We are importing the setupinfo module insite the unit tests
to be able to actually do tests.
"""
def setUp(self):
# we temper with version_info, play safe, keep copy
self._sys_version_info = sys.version_info
def tearDown(self):
# we temper with version_info, play safe, restore copy
sys.version_info = self._sys_version_info
def test_name(self):
"""Test the name of Connector/Python"""
import setupinfo
self.assertEqual('mysql-connector-python', setupinfo.name)
def test_dev_statuses(self):
"""Test the development statuses"""
import setupinfo
exp = {
'a': '3 - Alpha',
'b': '4 - Beta',
'rc': '4 - Beta',
'': '5 - Production/Stable'
}
self.assertEqual(exp, setupinfo.DEVELOPMENT_STATUSES)
def test_package_dir(self):
"""Test the package directory"""
import setupinfo
exp = {
'': 'lib',
}
self.assertEqual(exp, setupinfo.package_dir)
def test_unsupported_python(self):
"""Test if old Python version are unsupported"""
import setupinfo
tmp = sys.version_info
sys.version_info = (3, 0, 0, 'final', 0)
try:
imp.reload(setupinfo)
except RuntimeError:
pass
else:
self.fail("RuntimeError not raised with unsupported Python")
sys.version_info = tmp
def test_version(self):
"""Test the imported version information"""
import setupinfo
ver = setupinfo.VERSION
exp = '{0}.{1}.{2}'.format(*ver[0:3])
self.assertEqual(exp, setupinfo.version)
def test_misc_meta(self):
"""Test miscellaneous data such as URLs"""
import setupinfo
self.assertEqual(
'http://dev.mysql.com/doc/connector-python/en/index.html',
setupinfo.url)
self.assertEqual(
'http://dev.mysql.com/downloads/connector/python/',
setupinfo.download_url)
def test_classifiers(self):
"""Test Trove classifiers"""
import setupinfo
for clsfr in setupinfo.classifiers:
if 'Programming Language :: Python' in clsfr:
ver = clsfr.replace('Programming Language :: Python :: ', '')
if ver not in ('2.6', '2.7', '3', '3.1', '3.2', '3.3'):
self.fail('Unsupported version in classifiers')
if 'Development Status ::' in clsfr:
status = clsfr.replace('Development Status :: ', '')
self.assertEqual(
setupinfo.DEVELOPMENT_STATUSES[setupinfo.VERSION[3]],
status)

Some files were not shown because too many files have changed in this diff Show More