# ____________________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2026 National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and Engineering
# Solutions of Sandia, LLC, the U.S. Government retains certain rights in this
# software. This software is distributed under the 3-clause BSD License.
# ____________________________________________________________________________________
import os.path
import re
import sys
import shutil
from decimal import Decimal
from pyomo.common.dependencies import attempt_import
from pyomo.dataportal import TableData
from pyomo.dataportal.factory import DataManagerFactory
pyodbc, pyodbc_available = attempt_import('pyodbc')
pypyodbc, pypyodbc_available = attempt_import('pypyodbc')
sqlite3, sqlite3_available = attempt_import('sqlite3')
pymysql, pymysql_available = attempt_import('pymysql')
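# Data-source options used with the database data managers below
# (read from self.options):
#   format=
#   using=     name of the Python database module that connect() imports
#   query=     SQL query to execute
#   user=      passed to the database module's connect()
#   password=  passed to the database module's connect()
#   table=     table read with "SELECT * FROM <table>" when no query is given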
[docs]
class db_Table(TableData):
    """Base TableData plugin that loads a table through a database connection.

    ``open()`` obtains a connection, ``read()`` runs a query on it and
    passes the column names and rows to ``self._set_data``, and
    ``close()`` closes a connection that ``open()`` created.  Subclasses
    set ``self.using`` to the name of the Python database module that
    :meth:`connect` imports.
    """

    def __init__(self):
        TableData.__init__(self)
        # Database module name; open() copies it into self.options.using
        # when it is not None.
        self.using = None

    def open(self):
        """Connect to the data source named by ``self.filename``.

        Surrounding double quotes are removed from the name.  If
        ``self._data`` is set it is used as the connection object;
        otherwise :meth:`connect` is called with the name and options.

        Raises:
            IOError: if no data source name was specified.
        """
        if not self.filename:
            raise IOError("No data source name was specified")
        if self.filename.startswith('"'):
            self.filename = self.filename[1:-1]
        if self.using is not None:
            self.options.using = self.using
        # self.db stays None if connect() raises
        self.db = None
        if self._data is not None:
            # A connection supplied by the caller is used as-is
            self.db = self._data
        else:
            self.db = self.connect(self.filename, self.options)

    def read(self):
        """Run the configured query and store its result via ``_set_data``.

        The query is ``self.options.query`` (with one layer of surrounding
        quotes removed) or, when no query is given, ``SELECT * FROM
        <table>``.  The column names from the cursor description form the
        header row.  Does nothing when no connection is open.

        Raises:
            IOError: if neither the 'query' nor the 'table' option is set.
        """
        if self.db is None:
            return
        if self.options.query is None:
            if self.options.table is None:
                raise IOError("Must specify 'query' or 'table' option!")
            self.options.query = 'SELECT * FROM %s' % self.options.table
        elif self.options.query.startswith(("'", '"')):
            self.options.query = self.options.query[1:-1]

        cursor = self.db.cursor()
        try:
            cursor.execute(self.options.query)
            rows = cursor.fetchall()
            header = [col[0] for col in cursor.description]
        except sqlite3.OperationalError:
            import logging

            logging.getLogger('pyomo.core').error(
                """Fatal error reading from an external ODBC data source.
This error was generated outside Pyomo by the Python connector to the
external data source:
%s
for the query:
%s
It is possible that you have an error in your external data file,
the ODBC connector for this data source is not correctly installed,
or that there is a bug in the ODBC connector.
"""
                % (self.filename, self.options.query)
            )
            raise
        finally:
            # Release the cursor whether or not the query succeeded
            cursor.close()

        data = [[self._convert_db_value(value) for value in row] for row in rows]
        self._set_data(header, data)

    @staticmethod
    def _convert_db_value(data):
        """Convert one value fetched from the database for ``_set_data``.

        Decimal values become float, None becomes '.', and strings are
        truncated at the first NUL character; anything else is returned
        unchanged.
        """
        if isinstance(data, Decimal):
            return float(data)
        if data is None:
            # NOTE(review): '.' is presumably Pyomo's missing-value marker
            return '.'
        if isinstance(data, str):
            # Presumably some connectors return NUL-padded strings; keep
            # only the text before the first NUL.
            nulidx = data.find('\x00')
            if nulidx > -1:
                return data[:nulidx]
        return data

    def close(self):
        """Close the connection created by :meth:`open`.

        A connection supplied through ``self._data`` belongs to the caller
        and is left open.  Calling this more than once is harmless.
        """
        db = getattr(self, 'db', None)
        if self._data is None and db is not None:
            db.close()
            self.db = None

    def connect(self, connection, options, kwds=None):
        """Connect using the database module named by ``options.using``.

        The module's ``connect`` is called with ``connection`` followed by
        whichever of ``options.user``, ``options.password`` and
        ``options.database`` are not None (positionally, in that order),
        plus the keyword arguments in ``kwds``.

        Returns:
            The connection object, or None if the module cannot be imported.
        """
        if kwds is None:
            kwds = {}
        try:
            mod = __import__(options.using)
        except ImportError:
            return None
        args = [connection]
        for value in (options.user, options.password, options.database):
            if value is not None:
                args.append(value)
        return mod.connect(*args, **kwds)
#
# NOTE: The pyodbc interface currently doesn't work. Notably, nothing
# sets the "table" or "query" options, which causes db_Table.read() to
# fail. This interface has been disabled by overriding pyodbc_available
# in sheet.py.
#
[docs]
@DataManagerFactory.register('pyodbc', "%s database interface" % 'pyodbc')
class pyodbc_db_Table(db_Table):
    """Data manager that reads tables through an ODBC connection.

    ``connect`` turns a file name or connection string into an ODBC
    connection string, choosing a driver from ``_drivers`` by file
    extension (or by the ``driver`` option).  On some connection errors
    it retries once with a generated DSN.
    """

    # Candidate ODBC driver names for each connection type / file
    # extension; _get_driver() prefers one that pyodbc reports installed.
    _drivers = {
        'mdb': ["Microsoft Access Driver (*.mdb)"],
        'xls': [
            "Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)",
            "Microsoft Excel Driver (*.xls)",
        ],
        'xlsx': ["Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)"],
        'xlsm': ["Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)"],
        'xlsb': ["Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)"],
        'mysql': ["MySQL"],
    }
    _drivers['access'] = _drivers['mdb']
    _drivers['excel'] = _drivers['xls']

    def __init__(self):
        db_Table.__init__(self)
        self.using = 'pyodbc'

    def available(self):
        return pyodbc_available

    def requirements(self):
        return 'pyodbc'

    def connect(self, connection, options):
        """Open an ODBC connection for ``connection``.

        ``connection`` is either a file name (its extension selects the
        driver) or an ODBC connection string.  If the first attempt fails
        with error code 'IM002' or '08001', the connection is retried once
        with a DSN.  When $HOME is set, the DSN is written to
        ``~/.odbc.ini``; any existing file is first copied to
        ``.odbc.ini.orig``.  Otherwise the connection string is rebuilt
        with an explicit driver.

        Raises:
            TypeError: propagated unchanged from the first attempt.
            Exception: any other connection failure, including failure of
                the retry.
        """
        if options.driver is not None:
            ctype = options.driver
        elif '.' in connection:
            # Use the file extension as the connection type
            ctype = connection.split('.')[-1]
        elif 'mysql' in connection.lower():
            ctype = 'mysql'
        else:
            ctype = ''

        extras = {}
        # '.xls' is also a substring of '.xlsx', '.xlsm' and '.xlsb'
        if ctype in ('xls', 'xlsx', 'xlsm', 'xlsb', 'excel') or '.xls' in connection:
            extras['autocommit'] = True

        connection = self.create_connection_string(ctype, connection, options)
        try:
            conn = db_Table.connect(self, connection, options, extras)
        except TypeError:
            raise
        except Exception as e:
            # Not every exception carries an error code in args[0]
            code = e.args[0] if e.args else None
            if code not in ('IM002', '08001'):
                # Propagate the exception
                raise
            if 'HOME' in os.environ:
                connstr = self._register_home_dsn(connection)
            else:
                connstr = self._driver_connection_string(connection)
            # Will raise its own exception on failure
            conn = db_Table.connect(self, connstr, options, extras)
        return conn

    def _register_home_dsn(self, connection):
        """Add a PYOMO<name> DSN for ``connection`` to $HOME/.odbc.ini.

        Returns a connection string that refers to the new DSN.
        """
        # Note: this only works on *nix platforms. It appears that
        # ODBC.INI is stored in the registry on windows.
        #
        # [JDS, 8 Oct 18]: I am not convinced that writing a .odbc.ini
        # file is necessary; see _driver_connection_string().
        odbcIniPath = os.path.join(os.environ['HOME'], '.odbc.ini')
        if os.path.exists(odbcIniPath):
            shutil.copy(odbcIniPath, odbcIniPath + '.orig')
            config = ODBCConfig(filename=odbcIniPath)
        else:
            config = ODBCConfig()
        dsninfo = self.create_dsn_dict(connection, config)
        # Strip filenames of funny characters
        dsnid = re.sub('[^A-Za-z0-9]', '', dsninfo['Database'])
        dsn = 'PYOMO{0}'.format(dsnid)
        config.add_source(dsn, dsninfo['Driver'])
        config.add_source_spec(dsn, dsninfo)
        config.write(odbcIniPath)
        return "DRIVER={{{0}}};DSN={1}".format(dsninfo['Driver'], dsn)

    def _driver_connection_string(self, connection):
        """Rebuild ``connection`` as KEY=value pairs with an explicit driver."""
        dsninfo = self.create_dsn_dict(connection, ODBCConfig())
        parts = []
        for k, v in dsninfo.items():
            # Brace-quote values containing spaces unless already braced
            if ' ' in v and (v[0] != "{" or v[-1] != "}"):
                parts.append("%s={%s}" % (k.upper(), v))
            else:
                parts.append("%s=%s" % (k.upper(), v))
        return ";".join(parts)

    def create_dsn_dict(self, argstr, existing_config):
        """Build a DSN specification dict from an ODBC connection string.

        ``argstr`` is parsed as ';'-separated ``key=value`` pairs (keys
        are case-insensitive).  If it names a DSN present in
        ``existing_config.source_specs``, that spec is returned as-is.
        Otherwise the spec is built from a DBQ (database file) entry or,
        for the MySQL driver, from the server/database/port/... entries.

        Raises:
            KeyError: if the DBQ file extension has no known driver.
            Exception: if no driver can be determined or the driver type
                is not supported.
        """
        result = {}
        argdict = {}
        for part in argstr.split(';'):
            if '=' in part:
                key, val = part.split('=', 1)
                argdict[key.lower().strip()] = val.strip()

        if 'driver' in argdict:
            result['Driver'] = "{0}".format(argdict['driver']).strip("{}")
        if 'dsn' in argdict:
            if argdict['dsn'] in existing_config.source_specs:
                return existing_config.source_specs[argdict['dsn']]
            import logging

            logging.getLogger("pyomo.core").warning(
                "DSN with name {0} not found. Attempting to continue with options...".format(
                    argdict['dsn']
                )
            )

        if 'dbq' in argdict:
            # Using a file for db access.
            if 'Driver' not in result:
                # _drivers holds *lists* of candidates; store one name
                ext = argdict['dbq'].split('.')[-1].lower()
                driver = self._get_driver(ext)
                if driver is None:
                    raise KeyError(ext)
                result['Driver'] = driver
            result['Database'] = argdict['dbq']
            result['Server'] = 'localhost'
            result['User'] = ''
            result['Password'] = ''
            result['Port'] = '5432'
            result['Description'] = argdict['dbq']
            for k, v in argdict.items():
                result.setdefault(k.capitalize(), v)
            return result

        if 'Driver' not in result:
            raise Exception("No driver specified, and no DBQ to infer from")
        if result['Driver'].lower() != "mysql":
            raise Exception(
                "Unknown driver type '{0}' for database connection".format(
                    result['Driver']
                )
            )
        result['Driver'] = "MySQL"
        result['Server'] = argdict.get('server', 'localhost')
        result['Database'] = argdict.get('database', '')
        result['Port'] = argdict.get('port', '3306')
        result['Socket'] = argdict.get('socket', '')
        result['Option'] = argdict.get('option', '')
        result['Stmt'] = argdict.get('stmt', '')
        result['User'] = argdict.get('user', '')
        result['Password'] = argdict.get('password', '')
        result['Description'] = argdict.get('description', '')
        return result

    def create_connection_string(self, ctype, connection, options):
        """Prefix ``connection`` with DRIVER/Dbq entries for a known ``ctype``.

        Driver names containing spaces are wrapped in braces.  For an
        unknown ``ctype`` the connection is returned unchanged.
        ``options`` is not used.
        """
        driver = self._get_driver(ctype)
        if not driver:
            return connection
        if ' ' in driver and (driver[0] != "{" or driver[-1] != "}"):
            return "DRIVER={%s};Dbq=%s;" % (driver, connection)
        return "DRIVER=%s;Dbq=%s;" % (driver, connection)

    def _get_driver(self, ctype):
        """Return the ODBC driver name to use for ``ctype``.

        Among the candidates in ``_drivers[ctype]``, prefer one listed by
        ``pyodbc.drivers()``; otherwise (arbitrarily) return the first
        candidate.  Returns None if ``ctype`` is not known.
        """
        candidates = self._drivers.get(ctype, [])
        if not candidates:
            return None
        # pyodbc may be absent (e.g. when used via the pypyodbc subclass);
        # then there is no installed-driver list to consult.
        installed = pyodbc.drivers() if pyodbc_available else ()
        for driver in candidates:
            if driver in installed:
                return driver
        return candidates[0]
[docs]
class ODBCError(Exception):
    """Error raised by ODBCConfig for invalid or malformed ODBC config data.

    The value the exception was created with is kept in ``parameter``,
    and ``repr()`` of the exception is ``repr()`` of that value.
    """

    def __init__(self, value):
        super().__init__(value)
        self.parameter = value

    def __repr__(self):
        return repr(self.parameter)
[docs]
class ODBCConfig:
    """
    Encapsulates an ODBC configuration file, usually odbc.ini or
    .odbc.ini, as specified by IBM. ODBC config data can be loaded
    either from a file or a string containing the relevant formatted
    data. Calling load() after initialization will update existing
    information in the config object with the new information.

    The '[ODBC Data Sources]' section is stored in ``sources`` (source
    name -> driver string), the '[ODBC]' section in ``odbc_info``, and
    every other section in ``source_specs`` (section name -> dict of
    key/value options).
    """

    def __init__(self, filename=None, data=None):
        """
        Create a new ODBC config instance, loading data from
        the given file and/or data string. Once initialized, the
        new config will contain the data represented in both
        arguments, if any. Data specified as a string argument
        will override that in the file.
        """
        # Hardcoded string required here.
        # See documentation:
        # https://www.ibm.com/docs/en/informix-servers/12.10?topic=SSGU8G_12.1.0/com.ibm.odbc.doc/ids_odbc_062.html
        self.ODBC_DS_KEY = 'ODBC Data Sources'
        self.ODBC_INFO_KEY = 'ODBC'
        self.file = filename
        self.sources = {}
        self.source_specs = {}
        self.odbc_info = {}
        self.load(self.file, data)

    def load(self, filename=None, data=None):
        """
        Load data from the given file and/or data string. If
        both are given, data contained in the string will override
        that in the file. If this config object already contains
        data, the new information loaded will update the old,
        replacing where keys are the same.

        Raises ODBCError if the input contains a malformed line.
        """
        sections = {}
        if filename is not None:
            with open(filename, 'r') as fileHandle:
                sections.update(self._get_sections(fileHandle.read()))
        if data is not None:
            sections.update(self._get_sections(data))
        if self.ODBC_DS_KEY in sections:
            self.sources.update(sections.pop(self.ODBC_DS_KEY))
        if self.ODBC_INFO_KEY in sections:
            self.odbc_info.update(sections.pop(self.ODBC_INFO_KEY))
        # Every remaining section is a per-source specification
        self.source_specs.update(sections)

    def __str__(self):
        return "<ODBC config: {0} sources, {1} source specs>".format(
            len(self.sources), len(self.source_specs)
        )

    def __eq__(self, other):
        if isinstance(other, ODBCConfig):
            return (
                self.sources == other.sources
                and self.source_specs == other.source_specs
                and self.odbc_info == other.odbc_info
            )
        return False

    def odbc_repr(self):
        """
        Get the full, odbc.ini-style representation of this
        ODBC configuration.
        """
        lines = ["[{0}]".format(self.ODBC_DS_KEY)]
        lines.extend(
            "{0} = {1}".format(name, driver) for name, driver in self.sources.items()
        )
        for name, spec in self.source_specs.items():
            lines.append("")
            lines.append("[{0}]".format(name))
            lines.extend("{0} = {1}".format(key, value) for key, value in spec.items())
        if self.odbc_info:
            lines.append("")
            lines.append("[{0}]".format(self.ODBC_INFO_KEY))
            lines.extend(
                "{0} = {1}".format(key, value)
                for key, value in self.odbc_info.items()
            )
        return "\n".join(lines) + "\n"

    def write(self, filename):
        """
        Write the current ODBC configuration to the given file.
        Depends on the odbc_repr() function for a string
        representation of the stored ODBC config information.
        """
        with open(filename, 'w') as f:
            f.write(self.odbc_repr())

    def add_source(self, name, driver):
        """
        Add an ODBC data source to the configuration. A data
        source consists of a unique source name and a driver
        string, which specifies how the source will be loaded.
        If a source name is not unique, it will replace the
        existing source of the same name. A source is required
        in order to have a source specification.
        """
        if name is None or driver is None or len(name) == 0 or len(driver) == 0:
            raise ODBCError("A source must specify both a name and a driver string")
        if name == self.ODBC_DS_KEY or name == self.ODBC_INFO_KEY:
            raise ODBCError("A source cannot use the reserved name '{0}'".format(name))
        self.sources[str(name)] = str(driver)

    def del_source(self, name):
        """
        Remove an ODBC data source from the configuration. If
        any source specifications are based on this source, they
        will be removed as well.
        """
        if name in self.sources:
            self.source_specs.pop(name, None)
            del self.sources[name]

    def add_source_spec(self, name, spec):
        """
        Add an ODBC data source specification to the configuration.
        A source specification consists of a unique name and
        a key-value mapping (i.e. dictionary) of options. In order
        to add a source specification, a data source with a matching
        name must exist in the configuration.
        """
        if name is None or spec is None or len(name) == 0:
            raise ODBCError(
                "A source spec must specify both a name and a spec dictionary"
            )
        if name not in self.sources:
            raise ODBCError(
                "A source spec must have a corresponding source; call .add_source() first"
            )
        self.source_specs[name] = dict(spec)

    def del_source_spec(self, name):
        """
        Remove an ODBC data source specification from the
        configuration.
        """
        self.source_specs.pop(name, None)

    def set_odbc_info(self, key, value):
        """
        Set an option for the ODBC handling specified in the
        configuration. An option consists of a key-value pair.
        Specifying an existing key will update the current value.
        """
        if key is None or value is None or len(key) == 0 or len(value) == 0:
            raise ODBCError("An ODBC info pair must specify both a key and a value")
        self.odbc_info[str(key)] = str(value)

    def _get_sections(self, data):
        """
        Parse a string for ODBC sections. The parsing algorithm proceeds
        roughly as follows:
        1. Split the string on newline ('\\n') characters and strip
           surrounding whitespace (including a trailing '\\r') from each line.
        2. Skip blank lines and comment lines (starting with '#' or ';').
        3. Iterate over lines, storing all key-value pair lines in a dictionary.
        4. When reaching a new section header (denoted by '[str]'), store the old
           key-value pairs under the old section name. Continue from step 3.
        5. On reaching end of data, store the last section and return a mapping
           from section names to dictionaries of key-value pairs in those sections.

        Key-value pairs that appear before the first section header are
        discarded.  Raises ODBCError on a malformed line.
        """
        sections = {}
        sectionKey = None
        sectionContents = {}
        for rawLine in data.split('\n'):
            # Stripping also handles CRLF line endings and indented lines
            line = rawLine.strip()
            if not line or line[0] in ('#', ';'):
                # Blank line or comment
                continue
            if len(line) < 2:
                # Not enough room for even 'k='; can't contain info
                raise ODBCError(
                    "Malformed line in ODBC config (no meaningful data): " + line
                )
            if line[0] == '[' and line[-1] == ']':
                # Starts a new section; '=' has no special meaning here
                sections[sectionKey] = sectionContents
                sectionKey = line[1:-1]
                sectionContents = {}
                continue
            # Must be key=value. No duplicate '=' permitted.
            if '=' not in line:
                raise ODBCError(
                    "Malformed line in ODBC config (no key-value mapping): " + line
                )
            key, value = line.split("=", 1)
            if '=' in value:
                raise ODBCError(
                    "Malformed line in ODBC config (too many '='): " + line
                )
            sectionContents[key.strip()] = value.strip()
        sections[sectionKey] = sectionContents
        sections.pop(None, None)
        return sections
[docs]
@DataManagerFactory.register('pypyodbc', "pypyodbc database interface")
class pypyodbc_db_Table(pyodbc_db_Table):
    """ODBC data manager backed by the ``pypyodbc`` module.

    All connection-string and DSN handling is inherited from
    pyodbc_db_Table; only the module named by ``using`` differs.
    """

    def __init__(self):
        super().__init__()
        self.using = 'pypyodbc'

    def available(self):
        return pypyodbc_available

    def requirements(self):
        return 'pypyodbc'

    def connect(self, connection, options):
        # db_Table.open() copies self.using into options.using before
        # calling connect(), so this should always hold.
        assert options['using'] == 'pypyodbc'
        return super().connect(connection, options)
[docs]
@DataManagerFactory.register('sqlite3', "sqlite3 database interface")
class sqlite3_db_Table(db_Table):
    """Data manager that reads tables from an existing SQLite database file."""

    def __init__(self):
        db_Table.__init__(self)
        self.using = 'sqlite3'

    def available(self):
        return sqlite3_available

    def requirements(self):
        return 'sqlite3'

    def connect(self, connection, options):
        """Open the SQLite database file ``connection``.

        If ``options.text_factory`` is set, it is installed on the new
        connection.

        Raises:
            FileNotFoundError: if ``connection`` does not exist.  This is an
                IOError subclass, as raised by db_Table.open, and still an
                Exception for existing callers.
        """
        assert options['using'] == 'sqlite3'
        # Check first: sqlite3.connect would otherwise create a new,
        # empty database file at this path.
        if not os.path.exists(connection):
            raise FileNotFoundError("No such file: " + connection)
        con = sqlite3.connect(connection)
        if options.text_factory:
            con.text_factory = options.text_factory
        return con
[docs]
@DataManagerFactory.register('pymysql', "pymysql database interface")
class pymysql_db_Table(db_Table):
    """Data manager that reads tables through the ``pymysql`` module.

    It adds no connection logic of its own.  The inherited
    ``db_Table.connect`` imports the module named by ``using`` and calls
    its ``connect``.  That call passes the data source name, followed by
    any user, password and database options, as positional arguments.
    """

    # NOTE(review): confirm the installed pymysql.connect accepts these
    # positional arguments.

    def __init__(self):
        super().__init__()
        self.using = 'pymysql'

    def requirements(self):
        return 'pymysql'

    def available(self):
        return pymysql_available