# ____________________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2026 National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and Engineering
# Solutions of Sandia, LLC, the U.S. Government retains certain rights in this
# software. This software is distributed under the 3-clause BSD License.
# ____________________________________________________________________________________
import sys
from operator import itemgetter
from pyomo.common.deprecation import deprecation_warning
from pyomo.common.numeric_types import (
native_types,
native_numeric_types,
native_complex_types,
)
from pyomo.core.expr.numeric_expr import (
NegationExpression,
ProductExpression,
DivisionExpression,
PowExpression,
AbsExpression,
UnaryFunctionExpression,
Expr_ifExpression,
MonomialTermExpression,
LinearExpression,
SumExpression,
ExternalFunctionExpression,
mutable_expression,
)
from pyomo.core.expr.relational_expr import (
EqualityExpression,
InequalityExpression,
RangedExpression,
)
from pyomo.core.expr.visitor import StreamBasedExpressionVisitor, _EvaluationVisitor
from pyomo.core.expr import is_fixed, value
from pyomo.core.base.expression import Expression
import pyomo.core.kernel as kernel
from pyomo.repn.util import (
BeforeChildDispatcher,
ExitNodeDispatcher,
ExprType,
FileDeterminism,
FileDeterminism_to_SortComponents,
InvalidNumber,
OrderedVarRecorder,
VarRecorder,
apply_node_operation,
complex_number_error,
initialize_exit_node_dispatcher,
nan,
sum_like_expression_types,
val2str,
check_constant,
)
_CONSTANT = ExprType.CONSTANT
_FIXED = ExprType.FIXED
_VARIABLE = ExprType.VARIABLE
_LINEAR = ExprType.LINEAR
_GENERAL = ExprType.GENERAL
def _merge_dict(dest_dict, src_dict, mult, flag):
if not src_dict:
return
if flag == 1:
for vid, coef in src_dict.items():
if vid in dest_dict:
dest_dict[vid] += coef
else:
dest_dict[vid] = coef
elif not flag:
# mult is 0. There is nothing to do, unless the src_dict has an InvalidNumber
for vid, coef in src_dict.items():
if coef.__class__ is InvalidNumber:
if vid in dest_dict:
dest_dict[vid] += mult * coef
else:
dest_dict[vid] = mult * coef
else:
for vid, coef in src_dict.items():
if vid in dest_dict:
dest_dict[vid] += mult * coef
else:
dest_dict[vid] = mult * coef
[docs]
class LinearRepn:
__slots__ = ("multiplier", "constant", "linear", "nonlinear")
[docs]
def __init__(self):
self.multiplier = 1
self.constant = 0
self.linear = {}
self.nonlinear = None
def __str__(self):
linear = (
"{"
+ ", ".join(f"{val2str(k)}: {val2str(v)}" for k, v in self.linear.items())
+ "}"
)
return (
f"{self.__class__.__name__}(mult={val2str(self.multiplier)}, "
f"const={val2str(self.constant)}, "
f"linear={linear}, "
f"nonlinear={self.nonlinear})"
)
def __repr__(self):
return str(self)
@staticmethod
def constant_flag(val):
return val
@staticmethod
def multiplier_flag(val):
return val
def walker_exitNode(self):
if self.nonlinear is not None:
return _GENERAL, self
elif self.linear:
return _LINEAR, self
else:
return _CONSTANT, self.multiplier * self.constant
def duplicate(self):
ans = self.__class__.__new__(self.__class__)
ans.multiplier = self.multiplier
ans.constant = self.constant
ans.linear = dict(self.linear)
ans.nonlinear = self.nonlinear
return ans
def to_expression(self, visitor):
if self.nonlinear is not None:
# We want to start with the nonlinear term (and use
# assignment to ans instead of addition) in case the term is
# a non-numeric node (like a relational expression)
ans = self.nonlinear
else:
ans = 0
with mutable_expression() as e:
if self.linear:
var_map = visitor.var_map
for vid, coef in self.linear.items():
if self.multiplier_flag(coef):
e += coef * var_map[vid]
if self.constant_flag(self.constant):
e += self.constant
if e.nargs() > 1:
ans += e
elif e.nargs() == 1:
ans += e.arg(0)
if self.multiplier_flag(self.multiplier) != 1:
ans *= self.multiplier
return ans
[docs]
def append(self, other):
"""Append a child result from acceptChildResult
Notes
-----
This method assumes that the operator was "+". It is implemented
so that we can directly use a LinearRepn() as a `data` object in
the expression walker (thereby allowing us to use the default
implementation of acceptChildResult [which calls
`data.append()`] and avoid the function call for a custom
callback).
"""
# Note that self.multiplier will always be 1 (we only call append()
# within a sum, so there is no opportunity for self.multiplier to
# change). Omitting the assertion for efficiency.
# assert self.multiplier == 1
_type, other = other
if _type <= _FIXED: # Note: catching _FIXED and _CONSTANT
self.constant += other
return
other_mult_flag = self.multiplier_flag(other.multiplier)
if other_mult_flag == 1:
self.constant += other.constant
if other.linear:
for vid, coef in other.linear.items():
if vid in self.linear:
self.linear[vid] += coef
else:
self.linear[vid] = coef
if other.nonlinear is not None:
if self.nonlinear is None:
self.nonlinear = other.nonlinear
else:
self.nonlinear += other.nonlinear
return
mult = other.multiplier
if not other_mult_flag:
# 0 * other, so you would think that there is nothing to
# add/change about self. However, there is a chance
# that other contains an InvalidNumber, so we should go
# looking for it...
if other.constant.__class__ is InvalidNumber:
self.constant += mult * other.constant
for vid, coef in other.linear.items():
if coef.__class__ is InvalidNumber:
if vid in self.linear:
self.linear[vid] += mult * coef
else:
self.linear[vid] = mult * coef
else:
#
# mult != 0 or 1
#
if self.constant_flag(other.constant):
self.constant += mult * other.constant
if other.linear:
for vid, coef in other.linear.items():
if vid in self.linear:
self.linear[vid] += mult * coef
else:
self.linear[vid] = mult * coef
if other.nonlinear is not None:
if self.nonlinear is None:
self.nonlinear = mult * other.nonlinear
else:
self.nonlinear += mult * other.nonlinear
[docs]
def to_expression(visitor, arg):
if arg[0] <= _VARIABLE: # Note: catching _VARIABLE, _FIXED, and _CONSTANT
return arg[1]
else:
return arg[1].to_expression(visitor)
#
# NEGATION handlers
#
def _handle_negation_constant(visitor, node, arg):
return (_CONSTANT, -1 * arg[1])
def _handle_negation_fixed(visitor, node, arg):
return (_FIXED, -1 * arg[1])
def _handle_negation_ANY(visitor, node, arg):
arg[1].multiplier *= -1
return arg
#
# PRODUCT handlers
#
def _handle_product_constant_constant(visitor, node, arg1, arg2):
ans = arg1[1] * arg2[1]
if ans.__class__ is InvalidNumber:
constant_flag = visitor.Result.constant_flag
if not constant_flag(arg1[1]) or not constant_flag(arg2[1]):
a = val2str(arg1[1])
b = val2str(arg2[1])
deprecation_warning(
f"Encountered {a}*{b} in expression tree. "
"Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
return _CONSTANT, 0
return _CONSTANT, ans
def _handle_product_fixed_fixed(visitor, node, arg1, arg2):
# This is valid for fixed * constant, and fixed * fixed
return _FIXED, arg1[1] * arg2[1]
def _handle_product_constant_ANY(visitor, node, arg1, arg2):
arg2[1].multiplier *= arg1[1]
return arg2
def _handle_product_ANY_constant(visitor, node, arg1, arg2):
arg1[1].multiplier *= arg2[1]
return arg1
def _handle_product_nonlinear(visitor, node, arg1, arg2):
# Note: the expectation is that this method is not called often: the
# Linear visitor is generally expected to be called on linear
# expressions. As such, we will not overly concern ourselves with
# performance here.
ans = visitor.Result()
if not visitor.expand_nonlinear_products:
ans.nonlinear = to_expression(visitor, arg1) * to_expression(visitor, arg2)
return _GENERAL, ans
#
# We are multiplying (and expanding) m(A + Bx + C(x)) * m(A + Bx + C(x))
_, x1 = arg1
_, x2 = arg2
# [mm]
ans.multiplier = x1.multiplier * x2.multiplier
# reset the multipliers so that to_expression doesn't re-apply them below
x1.multiplier = x2.multiplier = 1
# x1.const * x2.const [AA]
x1_const_flag = ans.constant_flag(x1.constant)
x2_const_flag = ans.constant_flag(x2.constant)
if x1_const_flag and x2_const_flag:
ans.constant = x1.constant * x2.constant
# x1.linear * x2.const [BA] + x1.const * x2.linear [AB]
_merge_dict(ans.linear, x1.linear, x2.constant, x2_const_flag)
_merge_dict(ans.linear, x2.linear, x1.constant, x1_const_flag)
NL = 0
if x2.nonlinear is not None and (
x1_const_flag or x2.nonlinear.__class__ is InvalidNumber
):
# [AC]
NL += x1.constant * x2.nonlinear
if x1.nonlinear is not None and (
x2_const_flag or x2.nonlinear.__class__ is InvalidNumber
):
# [CA]
NL += x2.constant * x1.nonlinear
# [BB] + [BC] + [CB] + [CC]
x1.constant = 0
x2.constant = 0
NL += to_expression(visitor, arg1) * to_expression(visitor, arg2)
if NL.__class__ in sum_like_expression_types and NL.nargs() == 1:
NL = NL.arg(0)
ans.nonlinear = NL
return _GENERAL, ans
#
# DIVISION handlers
#
def _handle_division_constant_constant(visitor, node, arg1, arg2):
return _CONSTANT, apply_node_operation(node, (arg1[1], arg2[1]))
def _handle_division_fixed_fixed(visitor, node, arg1, arg2):
return _FIXED, arg1[1] / arg2[1]
def _handle_division_ANY_constant(visitor, node, arg1, arg2):
repn = arg1[1]
# We can only apply the division operation (and reduce the
# multiplier to a native value) if both the multiplier is a native
# value AND the divisor is a constant. We know the latter is true
# here, but must check the former. There is also a special case if
# the divisor is 0: then we can reduce the multiplier to an
# InvalidNumber (using apply_node_operation) regardless of what the
# dividend is. Again, note that arg2 is a constant, so we can check
# it for 0 with bool()
if repn.multiplier.__class__ in native_numeric_types or not arg2[1]:
repn.multiplier = apply_node_operation(node, (repn.multiplier, arg2[1]))
else:
repn.multiplier /= arg2[1]
return arg1
def _handle_division_ANY_fixed(visitor, node, arg1, arg2):
arg1[1].multiplier /= arg2[1]
return arg1
def _handle_division_nonlinear(visitor, node, arg1, arg2):
ans = visitor.Result()
ans.nonlinear = to_expression(visitor, arg1) / to_expression(visitor, arg2)
return _GENERAL, ans
#
# EXPONENTIATION handlers
#
def _handle_pow_constant_constant(visitor, node, arg1, arg2):
ans = apply_node_operation(node, (arg1[1], arg2[1]))
if ans.__class__ in native_complex_types:
ans = complex_number_error(ans, visitor, node)
return _CONSTANT, ans
def _handle_pow_fixed_fixed(visitor, node, arg1, arg2):
return _FIXED, arg1[1] ** arg2[1]
def _handle_pow_ANY_constant(visitor, node, arg1, arg2):
_, exp = arg2
if exp == 1:
return arg1
elif exp > 1 and exp <= visitor.max_exponential_expansion and int(exp) == exp:
_type, _arg = arg1
ans = _type, _arg.duplicate()
for i in range(1, int(exp)):
ans = visitor.exit_node_dispatcher[(ProductExpression, ans[0], _type)](
visitor, None, ans, (_type, _arg.duplicate())
)
return ans
elif not exp:
return _CONSTANT, 1
return _handle_pow_nonlinear(visitor, node, arg1, arg2)
def _handle_pow_nonlinear(visitor, node, arg1, arg2):
ans = visitor.Result()
ans.nonlinear = to_expression(visitor, arg1) ** to_expression(visitor, arg2)
return _GENERAL, ans
#
# ABS and UNARY handlers
#
def _handle_unary_constant(visitor, node, arg):
ans = apply_node_operation(node, (arg[1],))
# Unary includes sqrt() which can return complex numbers
if ans.__class__ in native_complex_types:
ans = complex_number_error(ans, visitor, node)
return _CONSTANT, ans
def _handle_unary_fixed(visitor, node, arg):
return _FIXED, node.create_node_with_local_data((arg[1],))
def _handle_unary_nonlinear(visitor, node, arg):
ans = visitor.Result()
ans.nonlinear = node.create_node_with_local_data((to_expression(visitor, arg),))
return _GENERAL, ans
#
# NAMED EXPRESSION handlers
#
def _handle_named_constant(visitor, node, arg1):
# Record this common expression
visitor.subexpression_cache[id(node)] = arg1
return arg1
def _handle_named_ANY(visitor, node, arg1):
# Record this common expression
visitor.subexpression_cache[id(node)] = arg1
_type, arg1 = arg1
return _type, arg1.duplicate()
#
# EXPR_IF handlers
#
def _handle_expr_if_const(visitor, node, arg1, arg2, arg3):
_type, _test = arg1
assert _type is _CONSTANT
if _test:
if _test.__class__ is InvalidNumber:
# nan
return _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3)
return arg2
else:
return arg3
def _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3):
# Note: guaranteed that arg1 is not _CONSTANT
ans = visitor.Result()
ans.nonlinear = Expr_ifExpression(
(
to_expression(visitor, arg1),
to_expression(visitor, arg2),
to_expression(visitor, arg3),
)
)
return _GENERAL, ans
#
# Relational expression handlers
#
def _handle_equality_const(visitor, node, arg1, arg2):
# It is exceptionally likely that if we get here, one of the
# arguments is an InvalidNumber
args, causes = InvalidNumber.parse_args(arg1[1], arg2[1])
try:
ans = args[0] == args[1]
except:
ans = False
causes.append(str(sys.exc_info()[1]))
if causes:
ans = InvalidNumber(ans, causes)
return _CONSTANT, ans
def _handle_equality_general(visitor, node, arg1, arg2):
ans = visitor.Result()
ans.nonlinear = EqualityExpression(
(to_expression(visitor, arg1), to_expression(visitor, arg2))
)
return _GENERAL, ans
def _handle_inequality_const(visitor, node, arg1, arg2):
# It is exceptionally likely that if we get here, one of the
# arguments is an InvalidNumber
args, causes = InvalidNumber.parse_args(arg1[1], arg2[1])
try:
ans = args[0] <= args[1]
except:
ans = False
causes.append(str(sys.exc_info()[1]))
if causes:
ans = InvalidNumber(ans, causes)
return _CONSTANT, ans
def _handle_inequality_general(visitor, node, arg1, arg2):
ans = visitor.Result()
ans.nonlinear = InequalityExpression(
(to_expression(visitor, arg1), to_expression(visitor, arg2)), node.strict
)
return _GENERAL, ans
def _handle_ranged_const(visitor, node, arg1, arg2, arg3):
# It is exceptionally likely that if we get here, one of the
# arguments is an InvalidNumber
args, causes = InvalidNumber.parse_args(arg1[1], arg2[1], arg3[1])
try:
ans = args[0] <= args[1] <= args[2]
except:
ans = False
causes.append(str(sys.exc_info()[1]))
if causes:
ans = InvalidNumber(ans, causes)
return _CONSTANT, ans
def _handle_ranged_general(visitor, node, arg1, arg2, arg3):
ans = visitor.Result()
ans.nonlinear = RangedExpression(
(
to_expression(visitor, arg1),
to_expression(visitor, arg2),
to_expression(visitor, arg3),
),
node.strict,
)
return _GENERAL, ans
[docs]
def define_exit_node_handlers(_exit_node_handlers=None):
if _exit_node_handlers is None:
_exit_node_handlers = {}
_exit_node_handlers[NegationExpression] = {
None: _handle_negation_ANY,
(_CONSTANT,): _handle_negation_constant,
(_FIXED,): _handle_negation_fixed,
}
_exit_node_handlers[ProductExpression] = {
None: _handle_product_nonlinear,
(_CONSTANT, _CONSTANT): _handle_product_constant_constant,
(_CONSTANT, _FIXED): _handle_product_fixed_fixed,
(_CONSTANT, _LINEAR): _handle_product_constant_ANY,
(_CONSTANT, _GENERAL): _handle_product_constant_ANY,
(_FIXED, _CONSTANT): _handle_product_fixed_fixed,
(_FIXED, _FIXED): _handle_product_fixed_fixed,
(_FIXED, _LINEAR): _handle_product_constant_ANY,
(_FIXED, _GENERAL): _handle_product_constant_ANY,
(_LINEAR, _CONSTANT): _handle_product_ANY_constant,
(_LINEAR, _FIXED): _handle_product_ANY_constant,
(_GENERAL, _CONSTANT): _handle_product_ANY_constant,
(_GENERAL, _FIXED): _handle_product_ANY_constant,
}
_exit_node_handlers[MonomialTermExpression] = _exit_node_handlers[ProductExpression]
_exit_node_handlers[DivisionExpression] = {
None: _handle_division_nonlinear,
(_CONSTANT, _CONSTANT): _handle_division_constant_constant,
(_CONSTANT, _FIXED): _handle_division_fixed_fixed,
(_FIXED, _CONSTANT): _handle_division_fixed_fixed,
(_FIXED, _FIXED): _handle_division_fixed_fixed,
(_LINEAR, _CONSTANT): _handle_division_ANY_constant,
(_LINEAR, _FIXED): _handle_division_ANY_fixed,
(_GENERAL, _CONSTANT): _handle_division_ANY_constant,
(_GENERAL, _FIXED): _handle_division_ANY_fixed,
}
_exit_node_handlers[PowExpression] = {
None: _handle_pow_nonlinear,
(_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
(_CONSTANT, _FIXED): _handle_pow_fixed_fixed,
(_FIXED, _CONSTANT): _handle_pow_fixed_fixed,
(_FIXED, _FIXED): _handle_pow_fixed_fixed,
(_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
(_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
}
_exit_node_handlers[UnaryFunctionExpression] = {
None: _handle_unary_nonlinear,
(_CONSTANT,): _handle_unary_constant,
(_FIXED,): _handle_unary_fixed,
}
_exit_node_handlers[AbsExpression] = _exit_node_handlers[UnaryFunctionExpression]
_exit_node_handlers[Expression] = {
None: _handle_named_ANY,
(_CONSTANT,): _handle_named_constant,
(_FIXED,): _handle_named_constant,
}
_exit_node_handlers[Expr_ifExpression] = {None: _handle_expr_if_nonlinear}
for j in (_CONSTANT, _LINEAR, _GENERAL):
for k in (_CONSTANT, _LINEAR, _GENERAL):
_exit_node_handlers[Expr_ifExpression][
_CONSTANT, j, k
] = _handle_expr_if_const
_exit_node_handlers[EqualityExpression] = {
None: _handle_equality_general,
(_CONSTANT, _CONSTANT): _handle_equality_const,
}
_exit_node_handlers[InequalityExpression] = {
None: _handle_inequality_general,
(_CONSTANT, _CONSTANT): _handle_inequality_const,
}
_exit_node_handlers[RangedExpression] = {
None: _handle_ranged_general,
(_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
}
return _exit_node_handlers
[docs]
class LinearBeforeChildDispatcher(BeforeChildDispatcher):
[docs]
def __init__(self):
# Special handling for external functions: will be handled
# as terminal nodes from the point of view of the visitor
self[ExternalFunctionExpression] = self._before_external
# Special linear / summation expressions
self[MonomialTermExpression] = self._before_monomial
self[LinearExpression] = self._before_linear
self[SumExpression] = self._before_general_expression
@staticmethod
def _before_var(visitor, child):
_id = id(child)
if _id not in visitor.var_map:
if child.fixed:
return False, (_CONSTANT, check_constant(child.value, child, visitor))
visitor.var_recorder.add(child)
ans = visitor.Result()
ans.linear[_id] = 1
return False, (_LINEAR, ans)
@staticmethod
def _before_monomial(visitor, child):
#
# The following are performance optimizations for common
# situations (Monomial terms and Linear expressions)
#
arg1, arg2 = child._args_
if arg1.__class__ not in native_types:
try:
arg1 = check_constant(visitor.evaluate(arg1), arg1, visitor)
except (ValueError, ArithmeticError):
return True, None
# We want to check / update the var_map before processing "0"
# coefficients so that we are consistent with what gets added to the
# var_map (e.g., 0*x*y: y is processed by _before_var and will
# always be added, but x is processed here)
_id = id(arg2)
if _id not in visitor.var_map:
if arg2.fixed:
return False, (
_CONSTANT,
arg1 * check_constant(arg2.value, arg2, visitor),
)
visitor.var_recorder.add(arg2)
# Trap multiplication by 0 and nan. Note that arg1 was reduced
# to a numeric value at the beginning of this method.
if not arg1:
if arg2.fixed:
arg2 = check_constant(arg2.value, arg2, visitor)
if arg2.__class__ is InvalidNumber:
deprecation_warning(
f"Encountered {arg1}*{val2str(arg2)} in expression "
"tree. Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
return False, (_CONSTANT, arg1)
ans = visitor.Result()
ans.linear[_id] = arg1
return False, (_LINEAR, ans)
@staticmethod
def _before_linear(visitor, child):
var_map = visitor.var_map
ans = visitor.Result()
const = 0
linear = ans.linear
for arg in child.args:
if arg.__class__ is MonomialTermExpression:
arg1, arg2 = arg._args_
if arg1.__class__ not in native_types:
try:
arg1 = check_constant(visitor.evaluate(arg1), arg1, visitor)
except (ValueError, ArithmeticError):
return True, None
# Trap multiplication by 0 and nan. Note that arg1 was
# reduced to a numeric value at the beginning of this
# method.
if not arg1:
if arg2.fixed:
arg2 = check_constant(arg2.value, arg2, visitor)
if arg2.__class__ is InvalidNumber:
deprecation_warning(
f"Encountered {arg1}*{val2str(arg2)} in expression "
"tree. Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
continue
_id = id(arg2)
if _id not in var_map:
if arg2.fixed:
const += arg1 * check_constant(arg2.value, arg2, visitor)
continue
visitor.var_recorder.add(arg2)
linear[_id] = arg1
elif _id in linear:
linear[_id] += arg1
else:
linear[_id] = arg1
elif arg.__class__ in native_numeric_types:
const += arg
elif arg.is_variable_type():
_id = id(arg)
if _id not in var_map:
if arg.fixed:
const += check_constant(arg.value, arg, visitor)
continue
visitor.var_recorder.add(arg)
linear[_id] = 1
elif _id in linear:
linear[_id] += 1
else:
linear[_id] = 1
else:
try:
const += check_constant(visitor.evaluate(arg), arg, visitor)
except (ValueError, ArithmeticError):
return True, None
if linear:
ans.constant = const
return False, (_LINEAR, ans)
else:
return False, (_CONSTANT, const)
@staticmethod
def _before_named_expression(visitor, child):
_id = id(child)
if _id in visitor.subexpression_cache:
_type, expr = visitor.subexpression_cache[_id]
if _type is _CONSTANT:
return False, (_type, expr)
else:
return False, (_type, expr.duplicate())
else:
return True, None
@staticmethod
def _before_external(visitor, child):
ans = visitor.Result()
if all(is_fixed(arg) for arg in child.args):
try:
ans.constant = check_constant(visitor.evaluate(child), child, visitor)
return False, (_CONSTANT, ans)
except:
pass
ans.nonlinear = child
return False, (_GENERAL, ans)
[docs]
class LinearRepnVisitor(StreamBasedExpressionVisitor):
Result = LinearRepn
before_child_dispatcher = LinearBeforeChildDispatcher()
exit_node_dispatcher = ExitNodeDispatcher(
initialize_exit_node_dispatcher(define_exit_node_handlers())
)
expand_nonlinear_products = False
max_exponential_expansion = 1
[docs]
def __init__(
self,
subexpression_cache,
var_map=None,
var_order=None,
sorter=None,
var_recorder=None,
):
super().__init__()
self.subexpression_cache = subexpression_cache
if any(_ is not None for _ in (var_map, var_order, sorter)):
if var_recorder is not None:
raise ValueError(
"LinearRepnVisitor: cannot specify any of var_map, "
"var_order, or sorter with var_recorder"
)
deprecation_warning(
"var_map, var_order, and sorter are deprecated arguments to "
"LinearRepnVisitor(). Please pass the VarRecorder object directly.",
version='6.8.1',
)
var_recorder = OrderedVarRecorder(var_map, var_order, sorter)
if var_recorder is None:
var_recorder = VarRecorder(
{}, FileDeterminism_to_SortComponents(FileDeterminism.ORDERED)
)
self.var_recorder = var_recorder
self.var_map = var_recorder.var_map
self._eval_expr_visitor = _EvaluationVisitor(True)
self.evaluate = self._eval_expr_visitor.dfs_postorder_stack
def initializeWalker(self, expr):
walk, result = self.beforeChild(None, expr, 0)
if not walk:
return False, self.finalizeResult(result)
return True, expr
def beforeChild(self, node, child, child_idx):
return self.before_child_dispatcher[child.__class__](self, child)
def enterNode(self, node):
# SumExpression are potentially large nary operators. Directly
# populate the result
if node.__class__ in sum_like_expression_types:
return node.args, self.Result()
else:
return node.args, []
def exitNode(self, node, data):
if data.__class__ is self.Result:
return data.walker_exitNode()
#
# General expressions...
#
return self.exit_node_dispatcher[(node.__class__, *map(itemgetter(0), data))](
self, node, *data
)
def finalizeResult(self, result):
ans = result[1]
if ans.__class__ is not self.Result:
ans = self.Result()
assert result[0] <= _FIXED # Note: allowing _FIXED or _CONSTANT
ans.constant = result[1]
return ans
mult_flag = ans.multiplier_flag(ans.multiplier)
if mult_flag == 1:
# mult is identity: only thing to do is filter out zero coefficients
self._filter_zeros(ans)
return ans
elif not mult_flag:
# the multiplier has cleared out the entire expression.
# Warn if this is suppressing a NaN (unusual, and
# non-standard, but we will wait to remove this behavior
# for the time being)
if ans.constant.__class__ is InvalidNumber or any(
c.__class__ is InvalidNumber for c in ans.linear.values()
):
deprecation_warning(
f"Encountered {ans.multiplier}*nan in expression tree. "
"Mapping the NaN result to 0 for compatibility "
"with the lp_v1 writer. In the future, this NaN "
"will be preserved/emitted to comply with IEEE-754.",
version='6.6.0',
)
return self.Result()
# mult not in {0, 1}: factor it into the constant,
# linear coefficients, and nonlinear term
self._factor_multiplier_into_ans(ans, ans.multiplier)
return ans
def _filter_zeros(self, ans):
_flag = ans.constant_flag
# Note: creating the intermediate list is important, as we are
# modifying the dict in place.
for vid in [vid for vid, c in ans.linear.items() if not _flag(c)]:
del ans.linear[vid]
def _factor_multiplier_into_ans(self, ans, mult):
_flag = ans.constant_flag
linear = ans.linear
zeros = []
for vid, coef in linear.items():
prod = mult * coef
if _flag(prod):
linear[vid] = prod
else:
zeros.append(vid)
for vid in zeros:
del linear[vid]
if ans.nonlinear is not None:
ans.nonlinear *= mult
if _flag(ans.constant):
ans.constant *= mult
ans.multiplier = 1