# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
"""
Standalone tests for symbolic expression equivalence system.
This test file defines local copies of ONLY the minimal classes needed
(instead of importing them), so it does not trigger the full VeraGridEngine
import chain.
"""
import sys
import os
import hashlib
from typing import Any, Dict, List, Optional, Set, Tuple, Union
# NOTE(review): hard-coded, machine-specific path. Nothing in this file
# imports from it (the expression classes are defined locally below), so it
# appears to be a leftover from development; confirm before removing.
sys.path.insert(0, '/home/marina/PycharmProjects/VeraGrid/src')
class Const:
    """Constant leaf of an expression tree.

    Attributes:
        value: the constant's payload (``None`` by default).
        name: optional label, empty string by default.
        uid: identifier; falls back to ``id(self)`` when not supplied.
    """

    __slots__ = ("value", "name", "uid")

    def __init__(self, value=None, uid=None, name=""):
        self.value = value
        self.name = name
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        # Printed as the bare value, e.g. "3" or "2.5".
        return str(self.value)

    def __repr__(self):
        return str(self)
class Var:
    """Named symbolic variable leaf of an expression tree.

    ``uid`` falls back to ``id(self)`` when not supplied.
    """

    __slots__ = ("name", "uid")

    def __init__(self, name, uid=None):
        self.name = name
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return self.name

    # repr and str are the same: both return the bare variable name.
    __repr__ = __str__
class BinOp:
    """Binary-operation node ``left <op> right``.

    ``op`` is the operator string (the arithmetic helpers in this file use
    "+", "-", "*", "/" and "**"). ``uid`` falls back to ``id(self)``.
    """

    __slots__ = ("op", "left", "right", "uid")

    def __init__(self, left, op, right, uid=None):
        self.left = left
        self.op = op
        self.right = right
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        # Fully parenthesized infix form, e.g. "(a + b)".
        return "({} {} {})".format(self.left, self.op, self.right)

    def __repr__(self):
        return str(self)
class UnOp:
    """Unary-operation node ``op(operand)`` (e.g. negation with ``op == "-"``).

    ``uid`` falls back to ``id(self)`` when not supplied.
    """

    __slots__ = ("op", "operand", "uid")

    def __init__(self, op, operand, uid=None):
        self.op = op
        self.operand = operand
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return "{}({})".format(self.op, self.operand)

    def __repr__(self):
        return str(self)
class Func:
    """Function-application node ``op(arg)``, e.g. ``sin(x)``.

    ``op`` is the function name (empty string by default); ``uid`` falls
    back to ``id(self)`` when not supplied.
    """

    __slots__ = ("op", "arg", "uid")

    def __init__(self, arg, op="", uid=None):
        self.arg = arg
        self.op = op
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return "{}({})".format(self.op, self.arg)

    def __repr__(self):
        return str(self)
def _to_expr(val):
    """Wrap a plain Python number in a ``Const``; return anything else unchanged.

    Accepts ``int``, ``float`` and ``complex``. ``bool`` is covered too, as a
    subclass of ``int``. ``complex`` is included because ``_canonical_const``
    already normalizes complex constants. Without it, ``x + 1j`` would build
    a node whose child is a raw ``complex``, which ``_hash_expr`` rejects.
    """
    if isinstance(val, (int, float, complex)):
        return Const(val)
    return val
def _add(self, other):
    """Build ``self + other``."""
    return BinOp(self, "+", _to_expr(other))


def _radd(self, other):
    """Build ``other + self`` (reflected)."""
    return BinOp(_to_expr(other), "+", self)


def _sub(self, other):
    """Build ``self - other``."""
    return BinOp(self, "-", _to_expr(other))


def _rsub(self, other):
    """Build ``other - self`` (reflected)."""
    return BinOp(_to_expr(other), "-", self)


def _mul(self, other):
    """Build ``self * other``."""
    return BinOp(self, "*", _to_expr(other))


def _rmul(self, other):
    """Build ``other * self`` (reflected)."""
    return BinOp(_to_expr(other), "*", self)


def _truediv(self, other):
    """Build ``self / other``."""
    return BinOp(self, "/", _to_expr(other))


def _rtruediv(self, other):
    """Build ``other / self`` (reflected)."""
    return BinOp(_to_expr(other), "/", self)


def _pow(self, other):
    """Build ``self ** other``."""
    return BinOp(self, "**", _to_expr(other))


def _rpow(self, other):
    """Build ``other ** self`` (reflected)."""
    return BinOp(_to_expr(other), "**", self)


def _neg(self):
    """Build ``-self``."""
    return UnOp("-", self)


# Install the same arithmetic operators on every expression node type.
# Func is included so that expressions such as ``sin(x) + 1``,
# ``2 * cos(x)`` and ``-sin(x)`` can be built; without these dunders
# Python raises TypeError for them.
_ARITHMETIC_DUNDERS = {
    "__add__": _add,
    "__radd__": _radd,
    "__sub__": _sub,
    "__rsub__": _rsub,
    "__mul__": _mul,
    "__rmul__": _rmul,
    "__truediv__": _truediv,
    "__rtruediv__": _rtruediv,
    "__pow__": _pow,
    "__rpow__": _rpow,
    "__neg__": _neg,
}
for _node_type in (Const, Var, BinOp, UnOp, Func):
    for _dunder_name, _method in _ARITHMETIC_DUNDERS.items():
        setattr(_node_type, _dunder_name, _method)
del _node_type, _dunder_name, _method
def sin(x):
    """Return the symbolic node ``sin(x)``."""
    return Func(op="sin", arg=x)
def cos(x):
    """Return the symbolic node ``cos(x)``."""
    return Func(op="cos", arg=x)
# Binary operators whose operand order does not matter when hashing.
CommutativeOps = {"*", "+"}
def _canonical_const(value):
if value is None:
return ("none", None)
if isinstance(value, complex):
return ("complex", (float(value.real), float(value.imag)))
if isinstance(value, float):
if value == int(value):
return ("int", int(value))
return ("float", value)
return (type(value).__name__, value)
def _flatten_assoc(expr, op):
    """Return the operands of a nested chain of ``op`` BinOps, left to right.

    Any node that is not a BinOp with this exact ``op`` is a leaf of the
    chain, so a non-matching ``expr`` yields ``[expr]``.
    """
    operands = []
    pending = [expr]
    while pending:
        node = pending.pop()
        if isinstance(node, BinOp) and node.op == op:
            # Push right first so the left subtree is visited first.
            pending.append(node.right)
            pending.append(node.left)
        else:
            operands.append(node)
    return operands
def _sort_exprs(exprs, key_func):
return sorted(exprs, key=key_func)
def _hash_expr(expr):
    """Return a SHA-256 hex digest describing the structure of *expr*.

    - Constants are normalized with ``_canonical_const``, so ``2`` and
      ``2.0`` hash equally.
    - Variables hash by name.
    - For operators in ``CommutativeOps``, the two child digests are sorted
      before hashing, so ``a + b`` and ``b + a`` agree. Associativity is NOT
      handled here; callers normalize with ``canonical`` first.

    Raises:
        TypeError: if *expr* (or any sub-node) is not a known node type.
    """
    if isinstance(expr, Const):
        type_key, norm_val = _canonical_const(expr.value)
        payload = f"Const:{type_key}:{norm_val}"
    elif isinstance(expr, Var):
        payload = f"Var:{expr.name}"
    elif isinstance(expr, UnOp):
        payload = f"UnOp:{expr.op}:{_hash_expr(expr.operand)}"
    elif isinstance(expr, BinOp):
        child_hashes = [_hash_expr(expr.left), _hash_expr(expr.right)]
        if expr.op in CommutativeOps:
            # Sort the two digests as whole strings. Sorting the characters
            # of their concatenation (as before) threw away which digest
            # belonged to which child and collapsed distinct trees together.
            child_hashes.sort()
        joined = "".join(child_hashes)
        payload = f"BinOp:{expr.op}:{joined}"
    elif isinstance(expr, Func):
        payload = f"Func:{expr.op}:{_hash_expr(expr.arg)}"
    else:
        raise TypeError(f"Unknown expression type: {type(expr)}")
    return hashlib.sha256(payload.encode()).hexdigest()
def _build_sum(terms):
    """Chain *terms* into a left-nested ``+`` tree in ``_hash_expr`` order.

    Raises IndexError if *terms* is empty.
    """
    ordered = _sort_exprs(terms, _hash_expr)
    total = ordered[0]
    for addend in ordered[1:]:
        total = BinOp(total, "+", addend)
    return total
def _build_product(factors):
    """Chain *factors* into a left-nested ``*`` tree in ``_hash_expr`` order.

    Raises IndexError if *factors* is empty.
    """
    ordered = _sort_exprs(factors, _hash_expr)
    product = ordered[0]
    for factor in ordered[1:]:
        product = BinOp(product, "*", factor)
    return product
def _eval_func(op, value):
import math
if op == "sin":
return math.sin(value)
if op == "cos":
return math.cos(value)
if op == "tan":
return math.tan(value)
if op == "exp":
return math.exp(value)
if op == "log":
return math.log(value)
if op == "sqrt":
return math.sqrt(value)
if op == "abs":
return abs(value)
if op == "floor":
return math.floor(value)
if op == "ceil":
return math.ceil(value)
if op == "heaviside":
return 0.0 if value <= 0 else 1.0
raise ValueError(f"Cannot evaluate function {op}")
def _is_valued_const(expr):
    """Return True if *expr* is a ``Const`` whose value is not None."""
    return isinstance(expr, Const) and expr.value is not None


def _try_fold(operation, a, b):
    """Return ``Const(operation(a, b))``, or None if the arithmetic fails.

    Failures such as ``0 ** -1`` (ZeroDivisionError) or ``10.0 ** 1000``
    (OverflowError) leave the expression symbolic instead of aborting
    canonicalization.
    """
    try:
        return Const(operation(a, b))
    except (ArithmeticError, TypeError, ValueError):
        return None


def _canon_negate(operand):
    """Canonical form of ``-operand`` for an already-canonical *operand*."""
    if _is_valued_const(operand):
        return Const(-operand.value)
    if isinstance(operand, UnOp) and operand.op == "-":
        return operand.operand  # -(-y) -> y
    return UnOp("-", operand)


def _canon_sum(left, right):
    """Canonical form of ``left + right``: flatten, fold constants, sort terms."""
    terms = _flatten_assoc(BinOp(left, "+", right), "+")
    const_sum = 0.0
    others = []
    for term in terms:
        if _is_valued_const(term):
            const_sum += term.value
        else:
            others.append(term)
    # Drop the additive identity, but keep the constant if it is all there is.
    if const_sum != 0.0 or not others:
        others.append(Const(const_sum))
    if len(others) == 1:
        return others[0]
    return _build_sum(others)


def _canon_product(left, right):
    """Canonical form of ``left * right``: flatten, fold constants, sort factors."""
    factors = _flatten_assoc(BinOp(left, "*", right), "*")
    const_product = 1.0
    others = []
    for factor in factors:
        if _is_valued_const(factor):
            if factor.value == 0:
                return Const(0)  # zero is absorbing
            const_product *= factor.value
        else:
            others.append(factor)
    if const_product != 1.0:
        others.append(Const(const_product))
    if not others:
        return Const(1)
    if len(others) == 1:
        return others[0]
    return _build_product(others)


def _canon_binop(left, op, right):
    """Canonical form of ``left <op> right`` for already-canonical operands."""
    if op == "+":
        return _canon_sum(left, right)
    if op == "*":
        return _canon_product(left, right)
    both_const = _is_valued_const(left) and _is_valued_const(right)
    if op == "-":
        if both_const:
            folded = _try_fold(lambda a, b: a - b, left.value, right.value)
            if folded is not None:
                return folded
        if isinstance(right, Const) and right.value == 0:
            return left
        if isinstance(left, Const) and left.value == 0:
            # Apply the negation rules, so 0 - (-y) becomes y.
            return _canon_negate(right)
        return BinOp(left, "-", right)
    if op == "/":
        right_is_zero = isinstance(right, Const) and right.value == 0
        if both_const and not right_is_zero:
            folded = _try_fold(lambda a, b: a / b, left.value, right.value)
            if folded is not None:
                return folded
        if isinstance(right, Const) and right.value == 1:
            return left
        if isinstance(left, Const) and left.value == 0 and not right_is_zero:
            # 0 / x -> 0, but 0 / 0 is undefined and stays symbolic.
            return Const(0)
        return BinOp(left, "/", right)
    if op == "**":
        if isinstance(right, Const) and right.value == 0:
            return Const(1)
        if isinstance(right, Const) and right.value == 1:
            return left
        if both_const:
            folded = _try_fold(lambda a, b: a ** b, left.value, right.value)
            if folded is not None:
                return folded
        return BinOp(left, "**", right)
    return BinOp(left, op, right)


def canonical(expr):
    """Return a canonical (normalized) copy of the expression tree *expr*.

    Normalizations are applied bottom-up:

    - ``+`` / ``*``:
      - flattened;
      - numeric constants folded;
      - neutral elements (0 / 1) dropped;
      - ``x * 0 -> 0``;
      - remaining operands sorted by ``_hash_expr``, so commutative and
        associative variants coincide.
    - ``-``, ``/``, ``**``:
      - constant folding when the arithmetic succeeds;
      - the identities ``x - 0``, ``0 - x``, ``x / 1``, ``0 / x``
        (for x not 0), ``x ** 0`` and ``x ** 1``.
    - unary ``-``: constants negated, double negation removed.
    - ``Func`` with a constant argument: evaluated via ``_eval_func`` when
      possible; otherwise kept symbolic.

    Unknown node types are returned unchanged.
    """
    if isinstance(expr, Const):
        return Const(expr.value)
    if isinstance(expr, Var):
        return Var(expr.name)
    if isinstance(expr, UnOp):
        operand = canonical(expr.operand)
        if expr.op == "-":
            return _canon_negate(operand)
        return UnOp(expr.op, operand)
    if isinstance(expr, BinOp):
        return _canon_binop(canonical(expr.left), expr.op, canonical(expr.right))
    if isinstance(expr, Func):
        arg = canonical(expr.arg)
        if _is_valued_const(arg):
            try:
                return Const(_eval_func(expr.op, arg.value))
            except (ValueError, TypeError, OverflowError):
                # Unknown function, domain error, overflow (e.g. exp(1000))
                # or non-real argument: keep the call symbolic.
                pass
        return Func(arg, expr.op)
    return expr
def _expand_power(expr):
    """Rewrite ``base ** n`` as a canonical product of *n* copies of *base*.

    Applies only when the exponent is a ``Const`` holding an integral
    int/float with 2 <= n <= 10. Otherwise *expr* is returned unchanged
    (same object), which ``_try_expand`` uses to detect "no change".
    """
    if not isinstance(expr, BinOp) or expr.op != "**":
        return expr
    base = expr.left
    exponent = expr.right
    if not isinstance(exponent, Const) or not isinstance(exponent.value, (int, float)):
        return expr
    value = exponent.value
    # int() would silently truncate 2.5 to 2 (expanding x**2.5 as x*x) and
    # raise on inf/nan, so only integral exponents are expanded.
    if isinstance(value, float) and not value.is_integer():
        return expr
    n = int(value)
    if not 2 <= n <= 10:
        return expr
    result = base
    for _ in range(n - 1):
        result = canonical(BinOp(result, "*", base))
    return result
def _get_product_vars(expr):
    """List the names of the ``Var`` factors of a (nested) product, left to right.

    Non-Var factors (constants, functions, powers, ...) contribute nothing.
    A bare Var yields ``[name]``; any other non-product yields ``[]``.
    """
    names = []
    stack = [expr]
    while stack:
        node = stack.pop()
        if isinstance(node, Var):
            names.append(node.name)
        elif isinstance(node, BinOp) and node.op == "*":
            stack.append(node.right)
            stack.append(node.left)
    return names
def _get_sorted_vars_key(expr):
    """Get a canonical key (sorted tuple) for variable comparison in products."""
    names = _get_product_vars(expr)
    names.sort()
    return tuple(names)
def _get_const_from_term(expr):
    """Return the numeric coefficient of a term.

    - ``Const``: its value, or 0.0 when the value is None.
    - ``*`` product: the product of every valued ``Const`` factor at any
      nesting depth of ``*``, or 1.0 if there is none. Nested products are
      searched so that ``(2 * x) * y`` reports 2 rather than 1; this
      matches ``_get_non_const_part``, which strips constants at every depth.
    - anything else: 1.0.
    """
    if isinstance(expr, Const):
        return expr.value if expr.value is not None else 0.0
    coefficient = 1.0
    if isinstance(expr, BinOp) and expr.op == "*":
        for child in (expr.left, expr.right):
            if isinstance(child, Const):
                if child.value is not None:
                    coefficient *= child.value
            elif isinstance(child, BinOp) and child.op == "*":
                coefficient *= _get_const_from_term(child)
    return coefficient
def _get_non_const_part(expr):
    """Return *expr* with its constant factors stripped.

    - ``Const`` -> ``Const(1)``.
    - ``Var`` -> itself.
    - ``a * b`` -> the product of the non-constant parts of ``a`` and ``b``,
      recursively. Sub-products made only of constants are dropped instead
      of leaving a stray ``Const(1)`` factor, so ``(2 * 3) * x`` yields
      ``x``. A product with no non-constant part yields ``Const(1)``.
    - anything else -> itself (not decomposed).
    """
    if isinstance(expr, Const):
        return Const(1)
    if isinstance(expr, Var):
        return expr
    if isinstance(expr, BinOp) and expr.op == "*":
        parts = []
        for child in (expr.left, expr.right):
            if isinstance(child, Const):
                continue
            part = _get_non_const_part(child)
            # A Const result can only come from a constant-only sub-product.
            if isinstance(part, Const):
                continue
            parts.append(part)
        if not parts:
            return Const(1)
        if len(parts) == 1:
            return parts[0]
        return BinOp(parts[0], "*", parts[1])
    return expr
def _expand_product_sum(expr):
    """Distribute one product or power over a sum or difference.

    Handled shapes:
    - ``(a + b) ** n`` for integral 2 <= n <= 10. n == 2 uses
      ``a**2 + 2ab + b**2``; larger n uses repeated multiplication plus
      ``expand``.
    - ``(a + b) * c`` and ``c * (a + b)``; the results have like terms
      combined.
    - ``(a - b) * c`` and ``c * (a - b)``; the results are canonicalized.

    Returns *expr* itself (same object) when no rule applies, which
    ``_try_expand`` relies on to detect "no change".
    """
    if not isinstance(expr, BinOp):
        return expr
    left = expr.left
    right = expr.right
    op = expr.op
    if op == "**" and isinstance(right, Const) and isinstance(right.value, (int, float)):
        value = right.value
        # Only integral exponents can be expanded: int() would truncate 2.5
        # to 2 (squaring (a+b)**2.5) and raise on inf/nan.
        if not isinstance(value, float) or value.is_integer():
            n = int(value)
            if n == 2 and isinstance(left, BinOp) and left.op == "+":
                a = left.left
                b = left.right
                a_sq = _expand_power(BinOp(a, "**", Const(2)))
                b_sq = _expand_power(BinOp(b, "**", Const(2)))
                two_ab = BinOp(Const(2), "*", BinOp(a, "*", b))
                result = BinOp(BinOp(a_sq, "+", two_ab), "+", b_sq)
                return _combine_like_terms(result)
            if 2 <= n <= 10 and isinstance(left, BinOp) and left.op == "+":
                result = left
                for _ in range(n - 1):
                    result = BinOp(result, "*", left)
                return _combine_like_terms(expand(result))
    if op == "*":
        left_exp = _expand_product_sum(left)
        right_exp = _expand_product_sum(right)
        if isinstance(left_exp, BinOp) and left_exp.op == "+":
            term1 = BinOp(left_exp.left, "*", right_exp)
            term2 = BinOp(left_exp.right, "*", right_exp)
            return _combine_like_terms(BinOp(term1, "+", term2))
        if isinstance(right_exp, BinOp) and right_exp.op == "+":
            term1 = BinOp(left_exp, "*", right_exp.left)
            term2 = BinOp(left_exp, "*", right_exp.right)
            return _combine_like_terms(BinOp(term1, "+", term2))
        if isinstance(left_exp, BinOp) and left_exp.op == "-":
            term1 = BinOp(left_exp.left, "*", right_exp)
            term2 = BinOp(left_exp.right, "*", right_exp)
            return canonical(BinOp(term1, "-", term2))
        if isinstance(right_exp, BinOp) and right_exp.op == "-":
            term1 = BinOp(left_exp, "*", right_exp.left)
            term2 = BinOp(left_exp, "*", right_exp.right)
            return canonical(BinOp(term1, "-", term2))
    return expr
def _combine_like_terms(expr):
    """Combine like terms in a sum by collecting all terms first.

    Non-sum expressions are returned unchanged.
    """
    is_sum = isinstance(expr, BinOp) and expr.op == "+"
    return _merge_terms_list(_collect_all_terms(expr)) if is_sum else expr
def _collect_all_terms(expr):
    """Collect all terms from a sum into a flat list (left to right)."""
    collected = []
    stack = [expr]
    while stack:
        node = stack.pop()
        if isinstance(node, BinOp) and node.op == "+":
            stack.append(node.right)
            stack.append(node.left)
        else:
            collected.append(node)
    return collected
def _merge_terms_list(terms):
    """Merge like terms from a flat list of terms into a sorted sum.

    Each term is flattened over ``*`` and split into:
    - a coefficient: the product of every ``Const`` factor with a non-None
      value;
    - its remaining non-constant factors.

    Terms whose non-constant factors are structurally identical are merged
    and their coefficients summed. "Identical" is decided by the sorted
    ``_hash_expr`` digests of the factors, so factor order does not matter.
    Grouping by these digests, rather than by variable names alone, keeps
    distinct terms such as ``sin(x)``, ``-x`` and plain constants apart.

    Groups whose coefficients cancel to 0 are dropped. The result is a
    left-nested ``+`` chain sorted by ``_hash_expr``, or ``Const(0)`` if
    nothing remains (including for an empty *terms* list).
    """
    # key -> [summed coefficient, representative non-constant factors]
    groups = {}
    for term in terms:
        coefficient = 1.0
        factors = []
        for factor in _flatten_assoc(term, "*"):
            if isinstance(factor, Const) and factor.value is not None:
                coefficient *= factor.value
            else:
                factors.append(factor)
        key = tuple(sorted(_hash_expr(f) for f in factors))
        if key in groups:
            groups[key][0] += coefficient
        else:
            groups[key] = [coefficient, factors]

    merged = []
    for coefficient, factors in groups.values():
        if coefficient == 0:
            continue
        if not factors:
            merged.append(Const(coefficient))
        elif coefficient == 1.0:
            merged.append(_build_product(factors))
        else:
            merged.append(BinOp(Const(coefficient), "*", _build_product(factors)))
    if not merged:
        return Const(0)  # everything cancelled (e.g. x + (-1)*x)
    merged.sort(key=_hash_expr)
    result = merged[0]
    for term in merged[1:]:
        result = BinOp(result, "+", term)
    return result
def _try_expand(expr):
    """Perform one bottom-up expansion pass over *expr*.

    For a BinOp:
    1. Both children are expanded first.
    2. The rebuilt node is distributed via ``_expand_product_sum``; for
       ``**`` it may instead be expanded into a product via
       ``_expand_power``.
    3. The result goes through ``_combine_like_terms``.

    UnOp and Func nodes are rebuilt only when their child changed (by
    identity). Any other node is returned unchanged.
    """
    if isinstance(expr, UnOp):
        inner = _try_expand(expr.operand)
        return expr if inner is expr.operand else UnOp(expr.op, inner)
    if isinstance(expr, Func):
        inner = _try_expand(expr.arg)
        return expr if inner is expr.arg else Func(inner, expr.op)
    if not isinstance(expr, BinOp):
        return expr
    rebuilt = BinOp(_try_expand(expr.left), expr.op, _try_expand(expr.right))
    distributed = _expand_product_sum(rebuilt)
    if distributed is not rebuilt:
        return _combine_like_terms(distributed)
    if rebuilt.op == "**":
        powered = _expand_power(rebuilt)
        if powered is not rebuilt:
            return _combine_like_terms(powered)
    return _combine_like_terms(rebuilt)
def expand(expr):
    """Repeatedly apply ``_try_expand`` until the printed form stops changing.

    At most 20 passes are made; the last result is returned either way.
    """
    current = expr
    passes = 0
    while passes < 20:
        passes += 1
        candidate = _try_expand(current)
        if str(candidate) == str(current):
            break
        current = candidate
    return current
def expand_and_canonicalize(expr):
    """Alternate ``canonical`` and ``expand``, starting from ``expand(expr)``.

    Stops as soon as either step leaves the printed form unchanged, or after
    20 rounds, and returns the latest expression.
    """
    current = expand(expr)
    for _ in range(20):
        normalized = canonical(current)
        if str(normalized) == str(current):
            return current
        current = normalized
        re_expanded = expand(current)
        if str(re_expanded) == str(current):
            return current
        current = re_expanded
    return current
def structural_hash(expr):
    """Hash of the canonical form of *expr* (no expansion)."""
    return _hash_expr(canonical(expr))
def structural_hash_expanded(expr):
    """Hash of *expr* after ``expand_and_canonicalize``."""
    return _hash_expr(expand_and_canonicalize(expr))
def equivalent(e1, e2):
    """True if *e1* and *e2* are the same object or share a structural hash."""
    return e1 is e2 or structural_hash(e1) == structural_hash(e2)
def equivalent_expanded(e1, e2):
    """True if *e1* and *e2* are the same object or agree after expansion."""
    return e1 is e2 or structural_hash_expanded(e1) == structural_hash_expanded(e2)
def equivalent_systems(sys1, sys2):
    """True if both systems hold the same multiset of structural hashes.

    Equation order is ignored; sizes must match.
    """
    if len(sys1) != len(sys2):
        return False
    return sorted(map(structural_hash, sys1)) == sorted(map(structural_hash, sys2))
def equations_equivalent(eq1, eq2):
    """Compare two equations given as BinOps whose operands are the two sides.

    Non-BinOp inputs, or BinOps with different operators, are compared by
    structural hash of the whole expression. Otherwise the equations match
    if any of these holds:
    - the sides match pairwise;
    - the sides match swapped (``a = b`` vs ``b = a``);
    - the residuals ``left - right`` canonicalize identically.
    """
    if not isinstance(eq1, BinOp) or not isinstance(eq2, BinOp):
        return structural_hash(eq1) == structural_hash(eq2)
    if eq1.op != eq2.op:
        # The side-wise checks below ignore the operator entirely, so without
        # this guard ``x + y`` and ``x * y`` were reported as equivalent.
        return structural_hash(eq1) == structural_hash(eq2)
    lhs1_hash = structural_hash(eq1.left)
    lhs2_hash = structural_hash(eq2.left)
    rhs1_hash = structural_hash(eq1.right)
    rhs2_hash = structural_hash(eq2.right)
    if lhs1_hash == lhs2_hash and rhs1_hash == rhs2_hash:
        return True
    if lhs1_hash == rhs2_hash and rhs1_hash == lhs2_hash:
        return True
    diff1 = canonical(BinOp(eq1.left, "-", eq1.right))
    diff2 = canonical(BinOp(eq2.left, "-", eq2.right))
    return structural_hash(diff1) == structural_hash(diff2)
def test_commutativity():
    """Test a + b == b + a"""
    a, b = Var("a"), Var("b")
    forward, backward = a + b, b + a
    h_forward = structural_hash(forward)
    h_backward = structural_hash(backward)
    print("Test commutativity (add):")
    print(f" a + b hash: {h_forward[:16]}...")
    print(f" b + a hash: {h_backward[:16]}...")
    assert h_forward == h_backward, f"Hashes should match: {h_forward} != {h_backward}"
    assert equivalent(forward, backward), "a + b should be equivalent to b + a"
    print(" PASSED\n")
def test_commutativity_mult():
    """Test a * b == b * a"""
    a, b = Var("a"), Var("b")
    ab, ba = a * b, b * a
    hash_ab = structural_hash(ab)
    hash_ba = structural_hash(ba)
    print("Test commutativity (mult):")
    print(f" a * b hash: {hash_ab[:16]}...")
    print(f" b * a hash: {hash_ba[:16]}...")
    assert hash_ab == hash_ba, f"Hashes should match: {hash_ab} != {hash_ba}"
    assert equivalent(ab, ba), "a * b should be equivalent to b * a"
    print(" PASSED\n")
def test_associativity():
    """Test (a + b) + c == a + (b + c)"""
    a, b, c = Var("a"), Var("b"), Var("c")
    grouped_left = (a + b) + c
    grouped_right = a + (b + c)
    print("Test associativity (add):")
    print(f" (a + b) + c: {grouped_left}")
    print(f" a + (b + c): {grouped_right}")
    hash_left = structural_hash(grouped_left)
    hash_right = structural_hash(grouped_right)
    print(f" hash1: {hash_left[:16]}...")
    print(f" hash2: {hash_right[:16]}...")
    assert hash_left == hash_right, f"Hashes should match: {hash_left} != {hash_right}"
    assert equivalent(grouped_left, grouped_right), "(a+b)+c should be equivalent to a+(b+c)"
    print(" PASSED\n")
def test_associativity_mult():
    """Test (a * b) * c == a * (b * c)"""
    a, b, c = Var("a"), Var("b"), Var("c")
    grouped_left = (a * b) * c
    grouped_right = a * (b * c)
    print("Test associativity (mult):")
    print(f" (a * b) * c: {grouped_left}")
    print(f" a * (b * c): {grouped_right}")
    hash_left = structural_hash(grouped_left)
    hash_right = structural_hash(grouped_right)
    assert hash_left == hash_right, f"Hashes should match: {hash_left} != {hash_right}"
    assert equivalent(grouped_left, grouped_right), "(a*b)*c should be equivalent to a*(b*c)"
    print(" PASSED\n")
def test_polynomial_equivalence():
    """Test (x + 1)^2 == x^2 + 2x + 1 using expanded comparison"""
    x = Var("x")
    squared = (x + Const(1)) ** Const(2)
    polynomial = x ** Const(2) + Const(2) * x + Const(1)
    print("Test polynomial equivalence:")
    print(f" (x + 1)^2: {squared}")
    print(f" x^2 + 2x + 1: {polynomial}")
    print(f" (x + 1)^2 expanded: {expand(squared)}")
    print(f" x^2 + 2x + 1 expanded: {expand(polynomial)}")
    print(f" (x + 1)^2 expand+canon: {expand_and_canonicalize(squared)}")
    print(f" x^2 + 2x + 1 expand+canon: {expand_and_canonicalize(polynomial)}")
    hash_squared = structural_hash_expanded(squared)
    hash_polynomial = structural_hash_expanded(polynomial)
    print(f" expanded hash1: {hash_squared[:16]}...")
    print(f" expanded hash2: {hash_polynomial[:16]}...")
    assert equivalent_expanded(squared, polynomial), "(x+1)^2 should equal x^2+2x+1 with expansion"
    print(" PASSED\n")
def test_function_stability():
    """Test sin(x + 0) == sin(x)"""
    x = Var("x")
    shifted = sin(x + Const(0))
    plain = sin(x)
    print("Test function stability:")
    print(f" sin(x + 0): {canonical(shifted)}")
    print(f" sin(x): {canonical(plain)}")
    hash_shifted = structural_hash(shifted)
    hash_plain = structural_hash(plain)
    assert hash_shifted == hash_plain, "sin(x+0) should hash same as sin(x)"
    assert equivalent(shifted, plain), "sin(x+0) should be equivalent to sin(x)"
    print(" PASSED\n")
def test_complex_expression():
    """Test more complex expressions"""
    x, y = Var("x"), Var("y")
    product = (x + y) * (x + y)
    expanded_form = x * x + Const(2) * x * y + y * y
    print("Test complex expression:")
    print(f" (x + y)^2: {product}")
    print(f" x^2 + 2xy + y^2: {expanded_form}")
    print(f" (x + y)^2 expanded: {expand(product)}")
    print(f" x^2 + 2xy + y^2 expanded: {expand(expanded_form)}")
    print(f" (x + y)^2 expand+canon: {expand_and_canonicalize(product)}")
    print(f" x^2 + 2xy + y^2 expand+canon: {expand_and_canonicalize(expanded_form)}")
    hash_product = structural_hash_expanded(product)
    hash_expanded = structural_hash_expanded(expanded_form)
    print(f" expanded hash1: {hash_product[:16]}...")
    print(f" expanded hash2: {hash_expanded[:16]}...")
    assert equivalent_expanded(product, expanded_form), "(x+y)^2 should equal x^2+2xy+y^2 with expansion"
    print(" PASSED\n")
def test_equations_equivalent():
    """Test equation equivalence considering a=b == b=a"""
    x = Var("x")
    forward = x - Const(1)
    reverse = Const(1) - x
    print("Test equations_equivalent:")
    print(" eq1: x - 1")
    print(" eq2: 1 - x")
    print(f" eq1 canonical: {canonical(forward)}")
    print(f" eq2 canonical: {canonical(reverse)}")
    hash_forward = structural_hash(canonical(forward))
    hash_reverse = structural_hash(canonical(reverse))
    print(f" hash1: {hash_forward[:16]}...")
    print(f" hash2: {hash_reverse[:16]}...")
    assert equations_equivalent(forward, reverse), "x-1 should be equivalent to 1-x as equations"
    print(" PASSED\n")
def test_system_equivalence():
    """Test that two systems with equations in different order are equivalent"""
    x, y = Var("x"), Var("y")
    system_a = [x + y, x * y, x - y]
    system_b = [x * y, x - y, x + y]
    print("Test system equivalence:")
    print(f" sys1 hashes: {[structural_hash(e)[:8] for e in system_a]}")
    print(f" sys2 hashes: {[structural_hash(e)[:8] for e in system_b]}")
    assert equivalent_systems(system_a, system_b), "Systems should be equivalent regardless of equation order"
    print(" PASSED\n")
def test_constant_folding():
    """Test that 2 + 3 folds to the same canonical constant as 5."""
    # The unused ``a = Var("a")`` local from the original was removed.
    expr1 = Const(2) + Const(3)
    expr2 = Const(5)
    print("Test constant folding:")
    print(f" 2 + 3 canonical: {canonical(expr1)}")
    print(f" 5 canonical: {canonical(expr2)}")
    h1 = structural_hash(expr1)
    h2 = structural_hash(expr2)
    assert h1 == h2, "2+3 should hash same as 5"
    assert equivalent(expr1, expr2), "2+3 should be equivalent to 5"
    print(" PASSED\n")
def test_neutral_elements():
    """Test a + 0 = a, a * 1 = a"""
    a = Var("a")
    print("Test neutral elements:")
    plus_zero = a + Const(0)
    h_plus_zero, h_a = structural_hash(plus_zero), structural_hash(a)
    print(f" a + 0 hash: {h_plus_zero[:16]}...")
    print(f" a hash: {h_a[:16]}...")
    assert equivalent(plus_zero, a), "a + 0 should be equivalent to a"
    times_one = a * Const(1)
    h_times_one, h_a_again = structural_hash(times_one), structural_hash(a)
    print(f" a * 1 hash: {h_times_one[:16]}...")
    print(f" a hash: {h_a_again[:16]}...")
    assert equivalent(times_one, a), "a * 1 should be equivalent to a"
    print(" PASSED\n")
def test_zero_product():
    """Test a * 0 = 0"""
    a = Var("a")
    print("Test zero product:")
    times_zero = a * Const(0)
    zero = Const(0)
    h_times_zero, h_zero = structural_hash(times_zero), structural_hash(zero)
    print(f" a * 0 hash: {h_times_zero[:16]}...")
    print(f" 0 hash: {h_zero[:16]}...")
    assert equivalent(times_zero, zero), "a * 0 should be equivalent to 0"
    print(" PASSED\n")
def test_power_normalization():
    """Test a**0 = 1, a**1 = a"""
    a = Var("a")
    print("Test power normalization:")
    zero_power = a ** Const(0)
    one = Const(1)
    h_zero_power, h_one = structural_hash(zero_power), structural_hash(one)
    print(f" a**0 hash: {h_zero_power[:16]}...")
    print(f" 1 hash: {h_one[:16]}...")
    assert equivalent(zero_power, one), "a**0 should be equivalent to 1"
    first_power = a ** Const(1)
    h_first_power, h_a = structural_hash(first_power), structural_hash(a)
    print(f" a**1 hash: {h_first_power[:16]}...")
    print(f" a hash: {h_a[:16]}...")
    assert equivalent(first_power, a), "a**1 should be equivalent to a"
    print(" PASSED\n")
def run_all_tests():
    """Run every test_* function in order and print a summary.

    An AssertionError counts as FAILED. Any other exception counts as ERROR
    (its traceback is printed) and is also tallied as a failure.

    Returns:
        True if no test failed, else False.
    """
    banner = "=" * 60
    print(banner)
    print("RUNNING SYMBOLIC EQUIVALENCE TESTS")
    print(banner + "\n")
    tests = (
        test_commutativity,
        test_commutativity_mult,
        test_associativity,
        test_associativity_mult,
        test_polynomial_equivalence,
        test_function_stability,
        test_complex_expression,
        test_equations_equivalent,
        test_system_equivalence,
        test_constant_folding,
        test_neutral_elements,
        test_zero_product,
        test_power_normalization,
    )
    passed = failed = 0
    for test in tests:
        try:
            test()
        except AssertionError as e:
            print(f" FAILED: {e}\n")
            failed += 1
        except Exception as e:
            import traceback
            print(f" ERROR: {e}")
            traceback.print_exc()
            print()
            failed += 1
        else:
            passed += 1
    print(banner)
    print(f"RESULTS: {passed} passed, {failed} failed")
    print(banner)
    return failed == 0
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    success = run_all_tests()
    exit_code = 0 if success else 1
    sys.exit(exit_code)