# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
"""
Symbolic Expression Equivalence Engine.
Provides canonicalization, structural hashing, DAG representation,
and equivalence checking for symbolic expressions without modifying
the original expression trees.
"""
from __future__ import annotations
import hashlib
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from VeraGridEngine.Utils.Symbolic.symbolic import (
Expr, Const, Var, BinOp, UnOp, Func, Func2,
_to_expr,
)
# Binary operator symbols whose two operands may be swapped; consulted by
# _hash_expr and _to_dag_recursive when ordering children.
CommutativeOps = {"+", "*"}
# The five arithmetic binary operator symbols.
# NOTE(review): not referenced anywhere else in the visible module — confirm it is still needed.
NonSimplifiableOps = {"+", "-", "*", "/", "**"}
[docs]
class DAGNode:
    """
    A single vertex of the DAG produced from an expression.

    Attributes:
        op: Label of the node (operator/function tag or node type name)
        children: Child DAGNode references (empty for leaves)
        structural_hash: Structural hash string computed for this node
    """

    __slots__ = ("op", "children", "structural_hash", "_expr")

    def __init__(
        self,
        op: str,
        children: List["DAGNode"],
        structural_hash: str,
        expr: Optional[Expr] = None,
    ):
        # Keep a reference to the expression the node was built from (may be None).
        self._expr: Optional[Expr] = expr
        self.structural_hash: str = structural_hash
        self.children: List["DAGNode"] = children
        self.op: str = op

    def __repr__(self) -> str:
        # Only the first 8 hash characters are shown to keep the output short.
        short_hash = self.structural_hash[:8]
        if not self.children:
            return f"DAGNode({self.op}, [], hash={short_hash})"
        return f"DAGNode({self.op}, [{len(self.children)} children], hash={short_hash})"
def _canonical_const(value: Any) -> Tuple[str, Any]:
"""Return canonical (type_key, normalized_value) for constants."""
if value is None:
return ("none", None)
if isinstance(value, complex):
return ("complex", (float(value.real), float(value.imag)))
if isinstance(value, float):
if value == int(value):
return ("int", int(value))
return ("float", value)
return (type(value).__name__, value)
def _flatten_assoc(expr: Expr, op: str) -> List[Expr]:
    """
    Collect the operands of a nested chain of the same binary operator.

    Example: ``(a + (b + c))`` with ``op="+"`` gives ``[a, b, c]``.

    :param expr: Expression to flatten
    :param op: Operator symbol whose nesting is flattened
    :return: Operands in left-to-right order
    """
    if not (isinstance(expr, BinOp) and expr.op == op):
        # Anything that is not a BinOp with the requested operator is one operand.
        return [expr]
    operands = _flatten_assoc(expr.left, op)
    operands.extend(_flatten_assoc(expr.right, op))
    return operands
def _sort_exprs(exprs: List[Expr], key_func) -> List[Expr]:
"""Sort expressions using structural hash as key for determinism."""
return sorted(exprs, key=key_func)
[docs]
def _sort_commutative_operands(operands: List[Expr]) -> List[Expr]:
    """Order operands of a commutative operator by structural hash (stable)."""
    try:
        return sorted(operands, key=_hash_expr)
    except TypeError:
        # _hash_expr rejects unknown expression types; keep the given order then.
        return list(operands)


def canonical(expr: Expr, depth: int = 0) -> Expr:
    """
    Transform an expression into canonical form.

    Applies:
    - Flatten associative operators (+, *)
    - Sort the non-constant operands of + and * by structural hash
      (the folded constant, if any, stays last)
    - Constant folding
    - Remove neutral elements (a+0=a, a*1=a, a*0=0)
    - Normalize powers (a**1=a, a**0=1)

    Constants whose value is None are never folded. Folds that would raise
    (division by zero, overflow, domain errors) are skipped and the
    operation is left symbolic.

    Does NOT mutate the original expression.

    :param expr: Expression to canonicalize
    :param depth: Current recursion depth; beyond 100 the expression is
        returned unchanged
    :return: New canonical expression
    """
    if depth > 100:
        return expr

    if isinstance(expr, Const):
        return Const(expr.value)

    if isinstance(expr, Var):
        return Var(
            name=expr.name,
            reference=expr.ref,
            network_conn=expr.network_conn,
            shared_reference=expr.shared_ref,
            non_mutable_uid=expr.non_mutable_uid,
            uid=expr.uid,
            diff_var=expr.diff_var,
            base_var=expr.base_var,
        )

    if isinstance(expr, UnOp):
        canonical_operand = canonical(expr.operand, depth + 1)
        if expr.op == "-":
            if isinstance(canonical_operand, Const) and canonical_operand.value is not None:
                return Const(-canonical_operand.value)
            if isinstance(canonical_operand, UnOp) and canonical_operand.op == "-":
                # -(-a) == a
                return canonical_operand.operand
            return UnOp("-", canonical_operand)
        return UnOp(expr.op, canonical_operand)

    if isinstance(expr, BinOp):
        left = canonical(expr.left, depth + 1)
        right = canonical(expr.right, depth + 1)
        # "Numeric" means a Const that actually carries a value to fold.
        left_num = isinstance(left, Const) and left.value is not None
        right_num = isinstance(right, Const) and right.value is not None

        if expr.op == "+":
            terms = _flatten_assoc(BinOp(left, "+", right), "+")
            const_sum = 0.0
            non_const_terms = []
            for t in terms:
                if isinstance(t, Const) and t.value is not None:
                    const_sum += t.value
                else:
                    non_const_terms.append(t)
            non_const_terms = _sort_commutative_operands(non_const_terms)
            if const_sum != 0.0 or not non_const_terms:
                non_const_terms.append(Const(const_sum))
            if len(non_const_terms) == 1:
                return non_const_terms[0]
            return _build_sum(non_const_terms)

        if expr.op == "*":
            factors = _flatten_assoc(BinOp(left, "*", right), "*")
            const_product = 1.0
            non_const_factors = []
            for f in factors:
                if isinstance(f, Const) and f.value is not None:
                    if f.value == 0:
                        # a*0 == 0: the whole product collapses.
                        return Const(0)
                    const_product *= f.value
                else:
                    non_const_factors.append(f)
            non_const_factors = _sort_commutative_operands(non_const_factors)
            if const_product != 1.0:
                non_const_factors.append(Const(const_product))
            if not non_const_factors:
                return Const(1)
            if len(non_const_factors) == 1:
                return non_const_factors[0]
            return _build_product(non_const_factors)

        if expr.op == "-":
            if left_num and right_num:
                return Const(left.value - right.value)
            if right_num and right.value == 0:
                return left
            if left_num and left.value == 0:
                if isinstance(right, UnOp) and right.op == "-":
                    # 0 - (-a) == a
                    return right.operand
                return UnOp("-", right)
            return BinOp(left, "-", right)

        if expr.op == "/":
            right_is_zero = right_num and right.value == 0
            if left_num and right_num and not right_is_zero:
                return Const(left.value / right.value)
            if right_num and right.value == 1:
                return left
            # 0/x folds to 0, but 0/0 must stay symbolic.
            if left_num and left.value == 0 and not right_is_zero:
                return Const(0)
            return BinOp(left, "/", right)

        if expr.op == "**":
            if right_num and right.value == 0:
                return Const(1)
            if right_num and right.value == 1:
                return left
            if left_num and right_num:
                try:
                    return Const(left.value ** right.value)
                except (ZeroDivisionError, OverflowError):
                    # e.g. 0 ** -1 or a float overflow: leave unfolded.
                    pass
            return BinOp(left, "**", right)

        return BinOp(left, expr.op, right)

    if isinstance(expr, Func):
        arg_canon = canonical(expr.arg, depth + 1)
        if isinstance(arg_canon, Const) and arg_canon.value is not None:
            try:
                return Const(_eval_func(expr.op, arg_canon.value))
            except (ValueError, TypeError, OverflowError):
                # Unknown function, domain error or overflow: keep symbolic.
                pass
        return Func(arg_canon, expr.op)

    if isinstance(expr, Func2):
        arg1_canon = canonical(expr.arg1, depth + 1)
        arg2_canon = canonical(expr.arg2, depth + 1)
        return Func2(expr.name, arg1_canon, arg2_canon)

    return expr
def _build_sum(terms: List[Expr]) -> Expr:
"""Build a canonical sum from a list of terms."""
if len(terms) == 0:
return Const(0)
if len(terms) == 1:
return terms[0]
result = terms[0]
for t in terms[1:]:
result = BinOp(result, "+", t)
return result
def _build_product(factors: List[Expr]) -> Expr:
"""Build a canonical product from a list of factors."""
if len(factors) == 0:
return Const(1)
if len(factors) == 1:
return factors[0]
result = factors[0]
for f in factors[1:]:
result = BinOp(result, "*", f)
return result
def _expand_power(expr: Expr) -> Expr:
    """
    Expand a small integer power into repeated multiplication: x**2 -> x*x.

    Only exponents that are exactly integral and between 2 and 10 are
    expanded. Fractional exponents such as 2.5 are left untouched; they
    must not be truncated to an integer.

    :param expr: Expression that may be a power
    :return: Expanded expression or the original
    """
    if not isinstance(expr, BinOp) or expr.op != "**":
        return expr
    base = expr.left
    exponent = expr.right
    if not isinstance(exponent, Const) or not isinstance(exponent.value, (int, float)):
        return expr
    power = exponent.value
    # is_integer() is False for inf/nan as well, so int() below cannot raise.
    if isinstance(power, float) and not power.is_integer():
        return expr
    n = int(power)
    if 2 <= n <= 10:
        result = base
        for _ in range(n - 1):
            result = canonical(BinOp(result, "*", base))
        return result
    return expr
def _integer_exponent(node: Expr) -> Optional[int]:
    """Return the exponent as an int when *node* is an exactly-integral Const, else None."""
    if not isinstance(node, Const):
        return None
    value = node.value
    if not isinstance(value, (int, float)):
        return None
    if isinstance(value, float) and not value.is_integer():
        # Fractional (or non-finite) exponents must not be truncated.
        return None
    return int(value)


def _expand_product_sum(expr: Expr) -> Expr:
    """
    Expand products of sums: (a+b)*(c+d) -> a*c + a*d + b*c + b*d

    Also handles (a+b)**n for integer n between 2 and 10, and distributes
    a product over a difference. Non-integer exponents are not expanded.

    :param expr: Binary operation expression
    :return: Expanded expression, or the original when nothing applies
    """
    if not isinstance(expr, BinOp):
        return expr
    left = expr.left
    right = expr.right
    op = expr.op

    if op == "**":
        n = _integer_exponent(right)
        if n is not None and isinstance(left, BinOp) and left.op == "+":
            if n == 2:
                # Binomial square: a^2 + 2ab + b^2
                a = left.left
                b = left.right
                a_sq = _expand_power(BinOp(a, "**", Const(2)))
                b_sq = _expand_power(BinOp(b, "**", Const(2)))
                two_ab = BinOp(Const(2), "*", BinOp(a, "*", b))
                result = BinOp(BinOp(a_sq, "+", two_ab), "+", b_sq)
                return _combine_like_terms_sum(result)
            if 2 <= n <= 10:
                # Write the power as a repeated product and expand that.
                result = left
                for _ in range(n - 1):
                    result = BinOp(result, "*", left)
                return _combine_like_terms_sum(expand(result))

    if op == "*":
        left_exp = _expand_product_sum(left)
        right_exp = _expand_product_sum(right)
        if isinstance(left_exp, BinOp) and left_exp.op == "+":
            term1 = BinOp(left_exp.left, "*", right_exp)
            term2 = BinOp(left_exp.right, "*", right_exp)
            return _combine_like_terms_sum(BinOp(term1, "+", term2))
        if isinstance(right_exp, BinOp) and right_exp.op == "+":
            term1 = BinOp(left_exp, "*", right_exp.left)
            term2 = BinOp(left_exp, "*", right_exp.right)
            return _combine_like_terms_sum(BinOp(term1, "+", term2))
        if isinstance(left_exp, BinOp) and left_exp.op == "-":
            term1 = BinOp(left_exp.left, "*", right_exp)
            term2 = BinOp(left_exp.right, "*", right_exp)
            return canonical(BinOp(term1, "-", term2))
        if isinstance(right_exp, BinOp) and right_exp.op == "-":
            term1 = BinOp(left_exp, "*", right_exp.left)
            term2 = BinOp(left_exp, "*", right_exp.right)
            return canonical(BinOp(term1, "-", term2))

    return expr
def _collect_sum_terms(expr: Expr) -> Expr:
    """
    Combine like terms of a sum; any other expression is returned as is.

    :param expr: Expression to process
    :return: Sum with like terms combined, or the input when it is not a "+" BinOp
    """
    is_sum = isinstance(expr, BinOp) and expr.op == "+"
    if is_sum:
        return _combine_terms_list(_flatten_and_collect_terms(expr))
    return expr
def _flatten_and_collect_terms(expr: Expr) -> List[Expr]:
    """
    Split a nested "+" chain into its individual terms.

    :param expr: Expression to split
    :return: Terms in left-to-right order (a one-element list for non-sums)
    """
    if not (isinstance(expr, BinOp) and expr.op == "+"):
        return [expr]
    collected = _flatten_and_collect_terms(expr.left)
    collected += _flatten_and_collect_terms(expr.right)
    return collected
def _combine_terms_list(terms: List[Expr]) -> Expr:
    """
    Combine like terms from a list of expressions.

    Each term is canonicalized and split into a constant coefficient and a
    non-constant part; terms whose non-constant parts hash equal have their
    coefficients added.

    :param terms: Terms of a sum
    :return: Canonical sum of the combined terms (Const(0) if all cancel)
    """
    # Maps hash of the non-constant part -> (accumulated coefficient, non-constant part).
    term_map: Dict[str, Tuple[float, Expr]] = {}
    for term in terms:
        term_canon = canonical(term)
        const_factor = _get_const_from_expr(term_canon)
        non_const = _remove_const_from_expr(term_canon)
        non_const_hash = _hash_expr(non_const)
        if non_const_hash in term_map:
            existing_const, _ = term_map[non_const_hash]
            term_map[non_const_hash] = (existing_const + const_factor, non_const)
        else:
            term_map[non_const_hash] = (const_factor, non_const)

    combined_terms = []
    for const_factor, non_const in term_map.values():
        if const_factor == 0:
            # Terms that cancelled out are dropped.
            continue
        if isinstance(non_const, Const):
            combined_terms.append(Const(const_factor))
        else:
            combined_terms.append(canonical(BinOp(Const(const_factor), "*", non_const)))

    if not combined_terms:
        return Const(0)

    # Hash ordering makes the result independent of the input term order.
    combined_terms.sort(key=_hash_expr)
    result = combined_terms[0]
    for t in combined_terms[1:]:
        result = canonical(BinOp(result, "+", t))
    return result
def _get_const_from_expr(expr: Expr) -> float:
    """
    Get the constant multiplier from an expression.

    For a product the multiplier is the product of the multipliers of all
    its factors, wherever the constants appear in the chain; this mirrors
    what _remove_const_from_expr strips.

    :param expr: Term to inspect
    :return: Constant multiplier (1.0 when there is none)
    """
    if isinstance(expr, Const):
        return expr.value if expr.value is not None else 1.0
    if isinstance(expr, BinOp) and expr.op == "*":
        total = 1.0
        for factor in _flatten_assoc(expr, "*"):
            # Do not stop at the first non-constant factor: canonical
            # products keep their constant last.
            total *= _get_const_from_expr(factor)
        return total
    if isinstance(expr, BinOp) and expr.op == "**":
        if isinstance(expr.right, Const) and expr.right.value == 2:
            # (c*x)**2 contributes c*c.
            base_const = _get_const_from_expr(expr.left)
            return base_const * base_const
    return 1.0
def _remove_const_from_expr(expr: Expr) -> Expr:
    """
    Strip constant factors from an expression.

    :param expr: Term to strip
    :return: The non-constant part; Const(1) when nothing non-constant remains
    """
    if isinstance(expr, Const):
        return Const(1)
    if isinstance(expr, Var):
        return expr
    if isinstance(expr, BinOp):
        if expr.op == "*":
            # Keep every factor that is not a Const carrying a value.
            kept = [
                _remove_const_from_expr(factor)
                for factor in _flatten_assoc(expr, "*")
                if not (isinstance(factor, Const) and factor.value is not None)
            ]
            if not kept:
                return Const(1)
            product = kept[0]
            for factor in kept[1:]:
                product = BinOp(product, "*", factor)
            return product
        if expr.op == "**" and isinstance(expr.right, Const) and expr.right.value == 2:
            # A square is rewritten as base*base with the constants removed from the base.
            stripped_base = _remove_const_from_expr(expr.left)
            return canonical(BinOp(stripped_base, "*", stripped_base))
    return expr
def _collect_and_combine_terms(expr: Expr) -> Expr:
    """
    Collect like terms in a sum and combine them.

    For example: x*y + y*x -> 2*x*y

    Both operands are processed once and the final merge is delegated to
    _do_combine_terms. Re-entering this function on the rebuilt sum would
    never terminate when the left operand is still a sum.

    :param expr: Expression to process
    :return: Expression with like terms combined, or the input when it is not a sum
    """
    if not (isinstance(expr, BinOp) and expr.op == "+"):
        return expr
    left = expr.left
    right = expr.right
    if isinstance(left, BinOp) and left.op == "+":
        left = _collect_and_combine_terms(left)
    if isinstance(right, BinOp):
        right = _collect_and_combine_terms(right)
    return _do_combine_terms(BinOp(left, "+", right))
def _is_simple_monomial(expr: Expr) -> bool:
    """Return True for terms made only of Const, Var, '*' and Var**Const."""
    if isinstance(expr, (Const, Var)):
        return True
    if isinstance(expr, BinOp):
        if expr.op == "*":
            return _is_simple_monomial(expr.left) and _is_simple_monomial(expr.right)
        if expr.op == "**":
            return isinstance(expr.left, Var) and isinstance(expr.right, Const)
    return False


def _do_combine_terms(expr: Expr) -> Expr:
    """
    Combine the two operands of a sum when they are like terms.

    Two operands are like terms when both are simple monomials with the
    same sorted variable list. Their coefficients are added; a missing
    coefficient counts as 1.

    :param expr: Expression to process
    :return: Canonical combined expression, or the input when it is not a sum
    """
    if not isinstance(expr, BinOp) or expr.op != "+":
        return expr
    left = expr.left
    right = expr.right
    if isinstance(left, BinOp) and left.op == "+":
        left_combined = _do_combine_terms(left)
        right_combined = _do_combine_terms(right) if isinstance(right, BinOp) else right
        return canonical(BinOp(left_combined, "+", right_combined))
    right = _do_combine_terms(right) if isinstance(right, BinOp) else right

    left_vars = _get_product_vars_sorted(left)
    right_vars = _get_product_vars_sorted(right)
    # Variable names alone cannot tell x from x*sin(y), so only pure
    # monomials are merged.
    if (
        left_vars
        and left_vars == right_vars
        and _is_simple_monomial(left)
        and _is_simple_monomial(right)
    ):
        left_const = _get_const_from_term(left)
        right_const = _get_const_from_term(right)
        left_coeff = 1.0 if left_const is None else left_const
        right_coeff = 1.0 if right_const is None else right_const
        # Like terms add: 2*x + 3*x == 5*x.
        total_const = left_coeff + right_coeff
        non_const = _get_non_const_part(left)
        return canonical(BinOp(Const(total_const), "*", non_const))
    return canonical(BinOp(left, "+", right))
def _get_product_vars_sorted(expr: Expr) -> List[str]:
    """
    Get the sorted list of variable names in a term (product, power or single var).

    A power with an exactly-integral constant exponent repeats the names of
    its base that many times. Any other exponent yields an empty list, so
    the term is never matched as a like term.

    :param expr: Term to inspect
    :return: Sorted variable names, with repetition for multiplicity
    """
    if isinstance(expr, Var):
        return [expr.name]
    if isinstance(expr, Const):
        return []
    if isinstance(expr, BinOp) and expr.op == "*":
        return sorted(_get_product_vars_sorted(expr.left) + _get_product_vars_sorted(expr.right))
    if isinstance(expr, BinOp) and expr.op == "**":
        exponent = expr.right
        if not isinstance(exponent, Const) or not isinstance(exponent.value, (int, float)):
            return []
        power = exponent.value
        if isinstance(power, float) and not power.is_integer():
            # x**2.5 must not share a key with x**2.
            return []
        # List repetition with a non-positive count gives []; sort so that
        # (a*b)**2 and a*a*b*b share a key.
        return sorted(_get_product_vars_sorted(expr.left) * int(power))
    return []
def _get_const_from_term(expr: Expr) -> Optional[float]:
    """
    Get the constant factor from a term (product or single const).

    Both sides of a product are searched recursively, so nested constants
    such as in 2*(3*x) are all accounted for. Constants whose value is None
    are ignored.

    :param expr: Term to inspect
    :return: Product of the constants found, or None when there are none
    """
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, BinOp) and expr.op == "*":
        found = [
            c
            for c in (_get_const_from_term(expr.left), _get_const_from_term(expr.right))
            if c is not None
        ]
        if found:
            result = 1.0
            for c in found:
                result *= c
            return result
    return None
def _get_non_const_part(expr: Expr) -> Expr:
    """
    Return a term with its constant factors removed.

    :param expr: Term to strip
    :return: Non-constant part; Const(1) when the term is purely constant
    """
    if isinstance(expr, Const):
        return Const(1)
    if isinstance(expr, Var) or not (isinstance(expr, BinOp) and expr.op == "*"):
        # Variables and anything that is not a product are kept unchanged.
        return expr
    kept = [
        _get_non_const_part(side)
        for side in (expr.left, expr.right)
        if not isinstance(side, Const)
    ]
    if not kept:
        return Const(1)
    if len(kept) == 1:
        return kept[0]
    return BinOp(kept[0], "*", kept[1])
def _get_product_vars(expr: Expr) -> List[str]:
    """
    List the variable names found in a product, in left-to-right order.

    The result is not sorted. Only Var leaves and "*" nodes are followed;
    every other node contributes nothing.

    :param expr: Term to inspect
    :return: Variable names (with repetition)
    """
    if isinstance(expr, Var):
        return [expr.name]
    if not (isinstance(expr, BinOp) and expr.op == "*"):
        return []
    names = _get_product_vars(expr.left)
    names.extend(_get_product_vars(expr.right))
    return names
def _get_const_factor_in_product(expr: Expr) -> Optional[float]:
    """
    Return the constant that is a direct operand of a product.

    :param expr: Constant or product expression
    :return: The Const value itself for a Const; for a "*" BinOp the first
        direct operand (left, then right) that is a Const with a value;
        otherwise None
    """
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, BinOp) and expr.op == "*":
        for side in (expr.left, expr.right):
            if isinstance(side, Const) and side.value is not None:
                return side.value
    return None
def _remove_const_from_product(expr: Expr) -> Expr:
    """
    Remove a constant factor from a product, returning the non-constant part.

    Only the direct operands of a "*" BinOp are inspected. Statements that
    used to follow the final ``return`` were unreachable and have been removed.

    :param expr: Expression to strip
    :return: The other operand when one direct operand is a Const, else the input
    """
    if isinstance(expr, BinOp) and expr.op == "*":
        if isinstance(expr.left, Const):
            return expr.right
        if isinstance(expr.right, Const):
            return expr.left
    return expr
def _get_product_vars(expr: Expr) -> List[str]:
    """
    List the variable names found in a product, in left-to-right order.

    The result is not sorted. Only Var leaves and "*" nodes are followed;
    every other node contributes nothing.

    :param expr: Term to inspect
    :return: Variable names (with repetition)
    """
    if isinstance(expr, Var):
        return [expr.name]
    is_product = isinstance(expr, BinOp) and expr.op == "*"
    if is_product:
        left_names = _get_product_vars(expr.left)
        right_names = _get_product_vars(expr.right)
        return left_names + right_names
    return []
def _get_const_factor_in_product(expr: Expr) -> Optional[float]:
    """
    Return the constant that is a direct operand of a product.

    :param expr: Constant or product expression
    :return: The Const value itself for a Const; for a "*" BinOp the first
        direct operand (left, then right) that is a Const with a value;
        otherwise None
    """
    if isinstance(expr, Const):
        return expr.value
    if not (isinstance(expr, BinOp) and expr.op == "*"):
        return None
    lhs, rhs = expr.left, expr.right
    if isinstance(lhs, Const) and lhs.value is not None:
        return lhs.value
    if isinstance(rhs, Const) and rhs.value is not None:
        return rhs.value
    return None
def _remove_const_from_product(expr: Expr) -> Expr:
    """
    Remove a constant factor from a product, returning the non-constant part.

    Only the direct operands of a "*" BinOp are inspected. Statements that
    used to follow the final ``return`` were unreachable and have been removed.

    :param expr: Expression to strip
    :return: The other operand when one direct operand is a Const, else the input
    """
    if isinstance(expr, BinOp) and expr.op == "*":
        if isinstance(expr.left, Const):
            return expr.right
        if isinstance(expr.right, Const):
            return expr.left
    return expr
def _get_sorted_vars_key(expr: Expr) -> Tuple[str, ...]:
    """
    Build an order-independent key from the variable names of a product.

    :param expr: Term to inspect
    :return: Tuple of the names reported by _get_product_vars, sorted
    """
    names = list(_get_product_vars(expr))
    names.sort()
    return tuple(names)
def _combine_like_terms_sum(expr: Expr) -> Expr:
    """
    Flatten a sum into its terms and merge the like ones.

    :param expr: Expression to process
    :return: Merged sum, or the input when it is not a "+" BinOp
    """
    if isinstance(expr, BinOp) and expr.op == "+":
        return _merge_terms_list_sum(_collect_all_terms_sum(expr))
    return expr
def _collect_all_terms_sum(expr: Expr) -> List[Expr]:
    """
    Gather every term of a nested "+" chain into one flat list.

    :param expr: Expression to split
    :return: Terms in left-to-right order (a one-element list for non-sums)
    """
    if not (isinstance(expr, BinOp) and expr.op == "+"):
        return [expr]
    flat = _collect_all_terms_sum(expr.left)
    flat.extend(_collect_all_terms_sum(expr.right))
    return flat
def _is_plain_monomial(expr: Expr) -> bool:
    """Return True when *expr* is built only from Const, Var and '*' nodes."""
    if isinstance(expr, (Const, Var)):
        return True
    if isinstance(expr, BinOp) and expr.op == "*":
        return _is_plain_monomial(expr.left) and _is_plain_monomial(expr.right)
    return False


def _merge_terms_list_sum(terms: List[Expr]) -> Expr:
    """
    Merge like terms from a flat list of terms.

    Plain monomials (constants, variables and their products) are grouped
    by their sorted variable names. Every other term (functions, powers,
    negations, ...) is grouped by the structural hash of its non-constant
    part, so that it is never merged with an unrelated term or a constant.

    :param terms: Terms of a sum
    :return: Canonical sum of the merged terms (Const(0) if nothing remains)
    """
    # key -> (accumulated coefficient, most recent term carrying that key)
    term_map: Dict[Tuple[str, ...], Tuple[float, Expr]] = {}
    for term in terms:
        if _is_plain_monomial(term):
            key = ("vars",) + _get_sorted_vars_key(term)
        else:
            key = ("hash", _hash_expr(_get_non_const_part(term)))
        # Use the coefficient as is: a zero coefficient must stay zero.
        const = _get_const_factor_in_term(term)
        if key in term_map:
            existing_const, _ = term_map[key]
            term_map[key] = (existing_const + const, term)
        else:
            term_map[key] = (const, term)

    merged = []
    for const, term in term_map.values():
        if const == 0:
            continue
        non_const = _get_non_const_part(term)
        if isinstance(non_const, Const):
            merged.append(Const(const))
        elif const == 1.0:
            merged.append(non_const)
        else:
            merged.append(canonical(BinOp(Const(const), "*", non_const)))

    if not merged:
        # Empty input, or every term cancelled out.
        return Const(0)

    merged.sort(key=_hash_expr)
    result = merged[0]
    for t in merged[1:]:
        result = canonical(BinOp(result, "+", t))
    return result
def _get_const_factor_in_term(expr: Expr) -> float:
    """
    Get the constant factor from a term (product or single const).

    Products are searched recursively, so constants nested inside a product
    chain, e.g. (2*x)*y, are included. This matches _get_non_const_part,
    which strips constants at every level of a product.

    :param expr: Term to inspect
    :return: Product of all constants found (1.0 when there are none)
    """
    if isinstance(expr, Const):
        return expr.value if expr.value is not None else 1.0
    if isinstance(expr, BinOp) and expr.op == "*":
        # The leading 1.0 keeps the result a float, as callers expect.
        return 1.0 * _get_const_factor_in_term(expr.left) * _get_const_factor_in_term(expr.right)
    return 1.0
def _sort_binop_args(expr: BinOp) -> Expr:
"""Sort arguments of a multiplication for canonical comparison."""
if expr.op == "*":
left_h = _hash_expr(expr.left)
right_h = _hash_expr(expr.right)
if left_h > right_h:
return BinOp(expr.right, expr.op, expr.left)
return expr
def _get_const_factor(expr: BinOp) -> Optional[float]:
"""Get constant factor from a multiplication expression."""
if expr.op == "*":
if isinstance(expr.left, Const) and expr.left.value is not None:
return expr.left.value
if isinstance(expr.right, Const) and expr.right.value is not None:
return expr.right.value
return None
def _remove_const_factor(expr: BinOp) -> Expr:
"""Remove constant factor from a multiplication expression."""
if expr.op == "*":
if isinstance(expr.left, Const):
return expr.right
if isinstance(expr.right, Const):
return expr.left
return expr
def _try_expand(expr: Expr, max_terms: int = 200) -> Expr:
    """
    Run one bottom-up expansion pass over an expression.

    :param expr: Expression to expand
    :param max_terms: Estimated term count above which the expression is
        returned untouched
    :return: The expanded form, or the original when nothing changed
    """
    if _count_terms(expr) > max_terms:
        return expr

    if isinstance(expr, BinOp):
        # Children first; the node is always rebuilt from the expanded children.
        rebuilt = BinOp(
            _try_expand(expr.left, max_terms),
            expr.op,
            _try_expand(expr.right, max_terms),
        )
        if rebuilt.op == "*":
            return _expand_product_sum(rebuilt)
        if rebuilt.op == "**":
            return _expand_power(rebuilt)
        return rebuilt

    if isinstance(expr, UnOp):
        inner = _try_expand(expr.operand, max_terms)
        # Reuse the original node when the operand came back unchanged.
        return expr if inner is expr.operand else UnOp(expr.op, inner)

    if isinstance(expr, Func):
        inner = _try_expand(expr.arg, max_terms)
        return expr if inner is expr.arg else Func(inner, expr.op)

    if isinstance(expr, Func2):
        first = _try_expand(expr.arg1, max_terms)
        second = _try_expand(expr.arg2, max_terms)
        if first is expr.arg1 and second is expr.arg2:
            return expr
        return Func2(expr.name, first, second)

    return expr
def _count_terms(e: Expr) -> int:
    """
    Cheap estimate of how many terms an expression would have once expanded.

    :param e: Expression to measure
    :return: Estimated term count (1 for leaves and unknown node types)
    """
    if isinstance(e, (Const, Var)):
        return 1
    if isinstance(e, UnOp):
        return _count_terms(e.operand)
    if isinstance(e, Func):
        return _count_terms(e.arg)
    if isinstance(e, Func2):
        return _count_terms(e.arg1) + _count_terms(e.arg2)
    if not isinstance(e, BinOp):
        return 1
    lhs = _count_terms(e.left)
    rhs = _count_terms(e.right)
    # A product of two multi-term operands multiplies out; everything else adds.
    if e.op == "*" and lhs > 1 and rhs > 1:
        return lhs * rhs
    return lhs + rhs
def _expand_once(e: Expr, max_terms: int, depth: int) -> Expr:
if depth > 20:
return e
if _count_terms(e) > max_terms:
return e
return _try_expand(e, max_terms)
def _canonical_once(e: Expr, depth: int) -> Expr:
if depth > 50:
return e
return canonical(e, depth=depth)
[docs]
def expand(expr: Expr, max_terms: int = 200) -> Expr:
    """
    Expand an algebraic expression.

    Applies (for at most three passes):
    - Product of sums expansion: (a+b)*(c+d) -> a*c + a*d + b*c + b*d
    - Power of sum expansion: (a+b)^2 -> a^2 + 2*a*b + b^2

    :param expr: Expression to expand
    :param max_terms: Maximum estimated number of terms before bailing out
    :return: Expanded expression
    """
    current = expr
    passes_left = 3
    while passes_left > 0:
        passes_left -= 1
        if _count_terms(current) > max_terms:
            break
        candidate = _expand_once(current, max_terms, 0)
        # A pass that leaves the printed form unchanged means a fixed point.
        if str(candidate) == str(current):
            break
        current = candidate
    return current
[docs]
def expand_and_canonicalize(expr: Expr, max_terms: int = 500) -> Expr:
    """
    Expand an expression and bring it to canonical form.

    One expansion pass is followed by up to three rounds of
    canonicalize-then-expand, stopping as soon as a step leaves the printed
    form unchanged.

    :param expr: Expression to process
    :param max_terms: Maximum estimated number of terms before bailing out
        (prevents exponential blowup)
    :return: Expanded and canonicalized expression
    """
    if _count_terms(expr) > max_terms:
        return expr

    current = _expand_once(expr, max_terms, 0)
    for _round in range(3):
        canon = _canonical_once(current, 0)
        if str(canon) == str(current):
            break
        current = canon

        re_expanded = _expand_once(current, max_terms, 0)
        if str(re_expanded) == str(current):
            break
        current = re_expanded
    return current
[docs]
def structural_hash_expanded(expr: Expr) -> str:
    """
    Structural hash computed after algebraic expansion.

    The expression is passed through expand_and_canonicalize() before
    hashing, with the aim that forms such as (x+1)^2 and x^2+2x+1 hash alike.

    :param expr: Expression to hash
    :return: Hex digest string produced by _hash_expr
    """
    expanded_form = expand_and_canonicalize(expr)
    digest = _hash_expr(expanded_form)
    return digest
[docs]
def equivalent_expanded(e1: Expr, e2: Expr) -> bool:
    """
    Equivalence check that expands both expressions before comparing.

    :param e1: First expression
    :param e2: Second expression
    :return: True when both are the same object or their expanded
        structural hashes match
    """
    # Identity short-circuits before any hashing is done.
    return e1 is e2 or structural_hash_expanded(e1) == structural_hash_expanded(e2)
def _eval_func(op: str, value: float) -> float:
"""Evaluate a unary function at a constant value."""
import math
if op == "sin":
return math.sin(value)
if op == "cos":
return math.cos(value)
if op == "tan":
return math.tan(value)
if op == "exp":
return math.exp(value)
if op == "log":
return math.log(value)
if op == "sqrt":
return math.sqrt(value)
if op == "asin":
return math.asin(value)
if op == "acos":
return math.acos(value)
if op == "atan":
return math.atan(value)
if op == "sinh":
return math.sinh(value)
if op == "cosh":
return math.cosh(value)
if op == "tanh":
return math.tanh(value)
if op == "abs":
return abs(value)
if op == "floor":
return math.floor(value)
if op == "ceil":
return math.ceil(value)
if op == "heaviside":
return 0.0 if value <= 0 else 1.0
raise ValueError(f"Cannot evaluate function {op}")
[docs]
def structural_hash(expr: Expr) -> str:
    """
    Deterministic structural hash of an expression.

    The expression is canonicalized first and then hashed with _hash_expr,
    which uses variable names and constant values rather than node UIDs and
    includes the operator of every node.

    :param expr: Expression to hash
    :return: SHA-256 hex digest string
    """
    return _hash_expr(canonical(expr))
def _hash_expr(expr: Expr) -> str:
    """
    Recursively compute the SHA-256 hex digest of a canonical expression.

    For commutative operators (see CommutativeOps) and for the symmetric
    two-argument functions min/max, the two child digests are ordered
    before hashing, so swapping the operands gives the same result.
    atan2 is order-sensitive and is hashed with its arguments in place.

    :param expr: Expression to hash
    :return: 64-character hex digest
    :raises TypeError: If the expression type is not recognized
    """

    def _digest(payload: str) -> str:
        return hashlib.sha256(payload.encode()).hexdigest()

    if isinstance(expr, Const):
        type_key, norm_val = _canonical_const(expr.value)
        return _digest(f"Const:{type_key}:{norm_val}")

    if isinstance(expr, Var):
        return _digest(f"Var:{expr.name}")

    if isinstance(expr, UnOp):
        return _digest(f"UnOp:{expr.op}:{_hash_expr(expr.operand)}")

    if isinstance(expr, BinOp):
        child_hashes = [_hash_expr(expr.left), _hash_expr(expr.right)]
        if expr.op in CommutativeOps:
            # Order the two digests as whole strings; sorting their
            # individual characters would discard most of the information.
            child_hashes.sort()
        return _digest(f"BinOp:{expr.op}:{''.join(child_hashes)}")

    if isinstance(expr, Func):
        return _digest(f"Func:{expr.op}:{_hash_expr(expr.arg)}")

    if isinstance(expr, Func2):
        arg_hashes = [_hash_expr(expr.arg1), _hash_expr(expr.arg2)]
        if expr.name in {"min", "max"}:
            arg_hashes.sort()
        return _digest(f"Func2:{expr.name}:{''.join(arg_hashes)}")

    raise TypeError(f"Unknown expression type: {type(expr)}")
[docs]
def to_dag(expr: Expr, memo: Optional[Dict[str, DAGNode]] = None) -> DAGNode:
    """
    Build a deduplicated DAG from an expression.

    The expression is canonicalized first. Nodes are keyed by structural
    hash, so subexpressions with the same hash share one DAG node.

    :param expr: Expression to convert
    :param memo: Optional memoization dict (structural_hash -> DAGNode);
        a fresh dict is used when omitted
    :return: DAGNode root
    """
    node_cache = {} if memo is None else memo
    return _to_dag_recursive(canonical(expr), node_cache)
def _to_dag_recursive(expr: Expr, memo: Dict[str, DAGNode]) -> DAGNode:
    """
    Recursively build the DAG for an expression with memoization.

    Children of commutative operators and of min/max are stored sorted by
    structural hash. The arguments of atan2 keep their order because the
    function is not symmetric.

    :param expr: Canonical expression to convert
    :param memo: Dict mapping structural hash -> DAGNode, updated in place
    :return: The (possibly shared) DAGNode for the expression
    :raises TypeError: If the expression type is not recognized
    """
    h = _hash_expr(expr)
    if h in memo:
        return memo[h]

    if isinstance(expr, Const):
        op, children = "Const", []
    elif isinstance(expr, Var):
        op, children = f"Var:{expr.name}", []
    elif isinstance(expr, UnOp):
        op, children = f"UnOp:{expr.op}", [_to_dag_recursive(expr.operand, memo)]
    elif isinstance(expr, BinOp):
        op = f"BinOp:{expr.op}"
        children = [_to_dag_recursive(expr.left, memo), _to_dag_recursive(expr.right, memo)]
        if expr.op in CommutativeOps:
            children = sorted(children, key=lambda n: n.structural_hash)
    elif isinstance(expr, Func):
        op, children = f"Func:{expr.op}", [_to_dag_recursive(expr.arg, memo)]
    elif isinstance(expr, Func2):
        op = f"Func2:{expr.name}"
        children = [_to_dag_recursive(expr.arg1, memo), _to_dag_recursive(expr.arg2, memo)]
        if expr.name in {"min", "max"}:
            children = sorted(children, key=lambda n: n.structural_hash)
    else:
        raise TypeError(f"Unknown expression type: {type(expr)}")

    node = DAGNode(op=op, children=children, structural_hash=h, expr=expr)
    memo[h] = node
    return node
[docs]
def equivalent(e1: Expr, e2: Expr) -> bool:
    """
    Equivalence check based on canonical structural hashes.

    :param e1: First expression
    :param e2: Second expression
    :return: True when both are the same object or their structural hashes match
    """
    # Identity short-circuits before any canonicalization is done.
    return e1 is e2 or structural_hash(e1) == structural_hash(e2)
[docs]
def equivalent_systems(sys1: List[Expr], sys2: List[Expr]) -> bool:
    """
    Compare two systems of equations irrespective of equation order.

    Each equation is reduced to its structural hash; the systems match when
    the sorted hash lists are equal.

    :param sys1: First list of equations
    :param sys2: Second list of equations
    :return: True if equivalent systems
    """
    if len(sys1) != len(sys2):
        return False
    first_digests = [structural_hash(eq) for eq in sys1]
    second_digests = [structural_hash(eq) for eq in sys2]
    first_digests.sort()
    second_digests.sort()
    return first_digests == second_digests
[docs]
def equations_equivalent(eq1: Expr, eq2: Expr) -> bool:
    """
    Check if two equations are equivalent.

    An equation is treated as a BinOp with operator "=". When either
    argument is not a BinOp, or the two top-level operators differ, the
    expressions are compared as plain expressions by structural hash.
    Otherwise the sides are compared directly, then swapped, and finally
    through the canonical difference ``left - right``.

    :param eq1: First equation (typically BinOp with op="=")
    :param eq2: Second equation
    :return: True if equivalent
    """
    if not isinstance(eq1, BinOp) or not isinstance(eq2, BinOp):
        return structural_hash(eq1) == structural_hash(eq2)
    if eq1.op != eq2.op:
        # Comparing sides only would wrongly equate e.g. a+b with a*b.
        return structural_hash(eq1) == structural_hash(eq2)

    lhs1_hash = structural_hash(eq1.left)
    lhs2_hash = structural_hash(eq2.left)
    rhs1_hash = structural_hash(eq1.right)
    rhs2_hash = structural_hash(eq2.right)
    if lhs1_hash == lhs2_hash and rhs1_hash == rhs2_hash:
        return True
    # Sides swapped: a = b versus b = a.
    if lhs1_hash == rhs2_hash and rhs1_hash == lhs2_hash:
        return True

    diff1 = canonical(BinOp(eq1.left, "-", eq1.right))
    diff2 = canonical(BinOp(eq2.left, "-", eq2.right))
    return structural_hash(diff1) == structural_hash(diff2)
[docs]
def get_structural_hash(expr: Expr) -> str:
    """
    Convenience alias for structural_hash().

    :param expr: Expression to hash
    :return: Hex string hash
    """
    digest = structural_hash(expr)
    return digest
[docs]
def dag_to_string(node: DAGNode, indent: int = 0) -> str:
    """
    Render a DAG as an indented, bracketed string for debugging.

    :param node: Root node to render
    :param indent: Number of leading spaces for this level
    :return: Multi-line string; leaves are rendered as their op label only
    """
    pad = " " * indent
    if node.children:
        rendered = ",\n".join(dag_to_string(child, indent + 1) for child in node.children)
        return f"{pad}{node.op} -> [\n" + rendered + f"\n{pad}]"
    return f"{pad}{node.op}"
[docs]
def simplify_deep(expr: Expr) -> Expr:
    """
    Repeatedly expand and canonicalize until the printed form stops changing.

    expand_and_canonicalize() is applied once and then re-applied up to ten
    more times; iteration stops as soon as a pass leaves the string form
    unchanged.

    :param expr: Expression to simplify
    :return: Simplified expression
    """
    current = expand_and_canonicalize(expr)
    last_text = str(current)
    for _attempt in range(10):
        candidate = expand_and_canonicalize(current)
        candidate_text = str(candidate)
        if candidate_text == last_text:
            # Fixed point reached; keep the previous result.
            break
        last_text = candidate_text
        current = candidate
    return current
# Public API of the module. Only names defined in this module are listed:
# "get_canonical_form" and "variables_in_corresponding_attributes" were
# removed because they are not defined here, which made
# ``from <module> import *`` fail with AttributeError.
__all__ = [
    "canonical",
    "structural_hash",
    "structural_hash_expanded",
    "to_dag",
    "DAGNode",
    "equivalent",
    "equivalent_expanded",
    "equivalent_systems",
    "equations_equivalent",
    "get_structural_hash",
    "dag_to_string",
    "simplify_deep",
    "expand",
    "expand_and_canonicalize",
    # Private helper, kept exported for backward compatibility.
    "_hash_expr",
]