# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
"""
Standalone tests for VariableAlignmentEngine.

This test file is self-contained: instead of importing from the project, it
inlines minimal copies of only the classes and helpers the tests need.
"""
import sys
import os
import hashlib
from typing import Any, Dict, List, Optional, Set, Tuple
sys.path.insert(0, '/home/marina/PycharmProjects/VeraGrid/src')
class Const:
    """A constant leaf in an expression tree.

    Attributes:
        value: The constant payload (may be ``None``).
        name: Optional label; not used when printing.
        uid: Identity tag; ``id(self)`` unless one is supplied.
    """

    __slots__ = ("value", "name", "uid")

    def __init__(self, value=None, uid=None, name=""):
        self.uid = id(self) if uid is None else uid
        self.value = value
        self.name = name

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return str(self)
class Var:
    """A named symbolic variable leaf.

    ``uid`` is the identity used for alignment; it defaults to ``id(self)``.
    Both ``str`` and ``repr`` render the bare name.
    """

    __slots__ = ("name", "uid")

    def __init__(self, name, uid=None):
        self.name = name
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name
class BinOp:
    """A binary operation node rendered as ``(left op right)``.

    ``op`` is the operator symbol string (``"+"``, ``"*"``, ``"**"``, ...).
    """

    __slots__ = ("op", "left", "right", "uid")

    def __init__(self, left, op, right, uid=None):
        self.op = op
        self.left, self.right = left, right
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return f"({self.left} {self.op} {self.right})"

    def __repr__(self):
        return str(self)
class UnOp:
    """A unary operation node rendered as ``op(operand)``."""

    __slots__ = ("op", "operand", "uid")

    def __init__(self, op, operand, uid=None):
        self.op, self.operand = op, operand
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return f"{self.op}({self.operand})"

    def __repr__(self):
        return str(self)
class Func:
    """A one-argument function call node rendered as ``op(arg)``.

    Note the argument order: ``Func(arg, op)``.
    """

    __slots__ = ("op", "arg", "uid")

    def __init__(self, arg, op="", uid=None):
        self.arg = arg
        self.op = op
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return f"{self.op}({self.arg})"

    def __repr__(self):
        return str(self)
class Func2:
    """A two-argument function call node rendered as ``name(arg1, arg2)``."""

    __slots__ = ("name", "arg1", "arg2", "uid")

    def __init__(self, name, arg1, arg2, uid=None):
        self.name = name
        self.arg1, self.arg2 = arg1, arg2
        self.uid = id(self) if uid is None else uid

    def __str__(self):
        return f"{self.name}({self.arg1}, {self.arg2})"

    def __repr__(self):
        return str(self)
def _to_expr(val):
if isinstance(val, (int, float)):
return Const(val)
return val
# Operator implementations shared by the expression node classes. Plain
# numeric operands are wrapped with _to_expr; the reflected (``__r*__``)
# variants place the other operand on the left-hand side.


def _add(self, other):
    rhs = _to_expr(other)
    return BinOp(self, "+", rhs)


def _radd(self, other):
    lhs = _to_expr(other)
    return BinOp(lhs, "+", self)


def _sub(self, other):
    rhs = _to_expr(other)
    return BinOp(self, "-", rhs)


def _rsub(self, other):
    lhs = _to_expr(other)
    return BinOp(lhs, "-", self)


def _mul(self, other):
    rhs = _to_expr(other)
    return BinOp(self, "*", rhs)


def _rmul(self, other):
    lhs = _to_expr(other)
    return BinOp(lhs, "*", self)


def _truediv(self, other):
    rhs = _to_expr(other)
    return BinOp(self, "/", rhs)


def _rtruediv(self, other):
    lhs = _to_expr(other)
    return BinOp(lhs, "/", self)


def _pow(self, other):
    rhs = _to_expr(other)
    return BinOp(self, "**", rhs)


def _rpow(self, other):
    lhs = _to_expr(other)
    return BinOp(lhs, "**", self)


def _neg(self):
    """Unary negation: ``-self`` becomes ``UnOp("-", self)``."""
    return UnOp("-", self)
# Attach the arithmetic operator overloads to every expression node class so
# expressions can be written with ordinary Python syntax (``x + 1``,
# ``2 * y``, ``-z``). Func and Func2 are included too, so function-call nodes
# compose like the other node types (e.g. ``Func(x, "sin") + 1``); without
# them such expressions raised TypeError.
_OPERATOR_OVERLOADS = {
    "__add__": _add,
    "__radd__": _radd,
    "__sub__": _sub,
    "__rsub__": _rsub,
    "__mul__": _mul,
    "__rmul__": _rmul,
    "__truediv__": _truediv,
    "__rtruediv__": _rtruediv,
    "__pow__": _pow,
    "__rpow__": _rpow,
    "__neg__": _neg,
}
for _node_cls in (Const, Var, BinOp, UnOp, Func, Func2):
    for _dunder, _impl in _OPERATOR_OVERLOADS.items():
        setattr(_node_cls, _dunder, _impl)
del _node_cls, _dunder, _impl
# Binary operators whose operand order does not matter; _hash_expr and the DAG
# builder sort the children of these operators.
CommutativeOps = set(("+", "*"))
def _canonical_const(value):
if value is None:
return ("none", None)
if isinstance(value, complex):
return ("complex", (float(value.real), float(value.imag)))
if isinstance(value, float):
if value == int(value):
return ("int", int(value))
return ("float", value)
return (type(value).__name__, value)
def _flatten_assoc(expr, op):
    """List the operands of a nested chain of ``op`` BinOps, left to right.

    A node that is not a BinOp with operator ``op`` is returned as a
    one-element list.
    """
    pending = [expr]
    operands = []
    while pending:
        node = pending.pop()
        if isinstance(node, BinOp) and node.op == op:
            # Push right first so the left subtree is visited first.
            pending.append(node.right)
            pending.append(node.left)
        else:
            operands.append(node)
    return operands
def _sort_exprs(exprs, key_func):
return sorted(exprs, key=key_func)
def _hash_expr(expr):
    """Return a SHA-256 hex digest describing the structure of *expr*.

    Variables hash by name (not uid) and constants by their normalized value
    (see ``_canonical_const``). The two children of an operator in
    ``CommutativeOps``, and the two arguments of ``min``/``max``, are hashed
    independently of their order.

    Raises:
        TypeError: If *expr* is not one of the known node types.
    """
    if isinstance(expr, Const):
        type_key, norm_val = _canonical_const(expr.value)
        payload = f"Const:{type_key}:{norm_val}"
    elif isinstance(expr, Var):
        payload = f"Var:{expr.name}"
    elif isinstance(expr, UnOp):
        payload = f"UnOp:{expr.op}:{_hash_expr(expr.operand)}"
    elif isinstance(expr, BinOp):
        child_hashes = [_hash_expr(expr.left), _hash_expr(expr.right)]
        if expr.op in CommutativeOps:
            # Order the two digests as whole strings. Sorting the characters
            # of their concatenation would lose which digit belongs to which
            # child and let unrelated trees collide.
            child_hashes.sort()
        payload = f"BinOp:{expr.op}:{''.join(child_hashes)}"
    elif isinstance(expr, Func):
        payload = f"Func:{expr.op}:{_hash_expr(expr.arg)}"
    elif isinstance(expr, Func2):
        arg_hashes = [_hash_expr(expr.arg1), _hash_expr(expr.arg2)]
        # Only symmetric functions may ignore argument order; atan2(y, x)
        # differs from atan2(x, y).
        if expr.name in ("min", "max"):
            arg_hashes.sort()
        payload = f"Func2:{expr.name}:{''.join(arg_hashes)}"
    else:
        raise TypeError(f"Unknown expression type: {type(expr)}")
    return hashlib.sha256(payload.encode()).hexdigest()
def _build_sum(terms):
    """Join *terms* into a left-nested ``+`` chain, ordered by structural hash.

    *terms* must be non-empty (an empty sequence raises IndexError).
    """
    ordered = _sort_exprs(terms, _hash_expr)
    total = ordered[0]
    for addend in ordered[1:]:
        total = BinOp(total, "+", addend)
    return total
def _build_product(factors):
    """Join *factors* into a left-nested ``*`` chain, ordered by structural hash.

    *factors* must be non-empty (an empty sequence raises IndexError).
    """
    ordered = _sort_exprs(factors, _hash_expr)
    product = ordered[0]
    for factor in ordered[1:]:
        product = BinOp(product, "*", factor)
    return product
def _eval_func(op, value):
import math
if op == "sin":
return math.sin(value)
if op == "cos":
return math.cos(value)
if op == "tan":
return math.tan(value)
if op == "exp":
return math.exp(value)
if op == "log":
return math.log(value)
if op == "sqrt":
return math.sqrt(value)
if op == "abs":
return abs(value)
if op == "floor":
return math.floor(value)
if op == "ceil":
return math.ceil(value)
if op == "heaviside":
return 0.0 if value <= 0 else 1.0
raise ValueError(f"Cannot evaluate function {op}")
def _has_value(node):
    """Return True if *node* is a Const whose value is not None."""
    return isinstance(node, Const) and node.value is not None


def _fold_constants(fn, left, right):
    """Fold two constant operands with *fn*.

    Returns the folded ``Const``, or ``None`` when the arithmetic fails
    (division by zero, overflow, unsupported operand types), so the caller
    can keep the node symbolic instead of crashing.
    """
    try:
        return Const(fn(left.value, right.value))
    except (ArithmeticError, TypeError):
        return None


def _canonical_binop(op, left, right):
    """Simplify ``left <op> right`` whose operands are already canonical."""
    if op == "+":
        terms = _flatten_assoc(BinOp(left, "+", right), "+")
        const_sum = 0.0
        others = []
        for term in terms:
            if _has_value(term):
                const_sum += term.value
            else:
                others.append(term)
        if const_sum != 0.0 or not others:
            others.append(Const(const_sum))
        if len(others) == 1:
            return others[0]
        return _build_sum(others)
    if op == "*":
        factors = _flatten_assoc(BinOp(left, "*", right), "*")
        const_product = 1.0
        others = []
        for factor in factors:
            if _has_value(factor):
                if factor.value == 0:
                    return Const(0)
                const_product *= factor.value
            else:
                others.append(factor)
        if const_product != 1.0:
            others.append(Const(const_product))
        if not others:
            return Const(1)
        if len(others) == 1:
            return others[0]
        return _build_product(others)
    if op == "-":
        if _has_value(left) and _has_value(right):
            folded = _fold_constants(lambda a, b: a - b, left, right)
            if folded is not None:
                return folded
        if isinstance(right, Const) and right.value == 0:
            return left
        if isinstance(left, Const) and left.value == 0:
            return UnOp("-", right)
        return BinOp(left, "-", right)
    if op == "/":
        right_is_zero = isinstance(right, Const) and right.value == 0
        if _has_value(left) and _has_value(right) and not right_is_zero:
            folded = _fold_constants(lambda a, b: a / b, left, right)
            if folded is not None:
                return folded
        if isinstance(right, Const) and right.value == 1:
            return left
        # 0 / x -> 0, but never fold the undefined 0 / 0.
        if isinstance(left, Const) and left.value == 0 and not right_is_zero:
            return Const(0)
        return BinOp(left, "/", right)
    if op == "**":
        if isinstance(right, Const) and right.value == 0:
            return Const(1)
        if isinstance(right, Const) and right.value == 1:
            return left
        if _has_value(left) and _has_value(right):
            folded = _fold_constants(lambda a, b: a ** b, left, right)
            if folded is not None:
                return folded
        return BinOp(left, "**", right)
    return BinOp(left, op, right)


def canonical(expr):
    """Return a simplified, order-normalized copy of *expr*.

    Rewrites applied bottom-up:
      * Unary ``-`` of a constant is folded; ``-(-x)`` becomes ``x``.
      * ``+`` / ``*`` chains are flattened. Constants are folded into one term
        or factor, and the remaining operands are sorted by structural hash.
        A product with a literal zero factor collapses to ``Const(0)``.
      * ``-``, ``/`` and ``**`` fold constant operands and drop identities:
        ``x - 0``, ``0 - x``, ``x / 1``, ``0 / x`` (not ``0 / 0``), ``x ** 0``
        and ``x ** 1``.
      * ``Func`` nodes with a constant argument are evaluated when possible.
      * ``Func2`` arguments are canonicalized recursively.

    Folding that would raise (e.g. ``0 ** -1``, ``exp(1000)``, ``log(-1)``)
    is skipped, and the node is kept symbolic. Leaves are rebuilt, so Var
    nodes get fresh uids (their names are kept). Unknown node types are
    returned unchanged.
    """
    if isinstance(expr, Const):
        return Const(expr.value)
    if isinstance(expr, Var):
        return Var(expr.name)
    if isinstance(expr, UnOp):
        operand = canonical(expr.operand)
        if expr.op != "-":
            return UnOp(expr.op, operand)
        if _has_value(operand):
            return Const(-operand.value)
        if isinstance(operand, UnOp) and operand.op == "-":
            return operand.operand  # -(-x) == x
        return UnOp("-", operand)
    if isinstance(expr, BinOp):
        return _canonical_binop(expr.op, canonical(expr.left), canonical(expr.right))
    if isinstance(expr, Func):
        arg = canonical(expr.arg)
        if _has_value(arg):
            try:
                return Const(_eval_func(expr.op, arg.value))
            except (ArithmeticError, ValueError, TypeError):
                # Not foldable (unknown function, domain error, overflow):
                # keep the call symbolic.
                pass
        return Func(arg, expr.op)
    if isinstance(expr, Func2):
        return Func2(expr.name, canonical(expr.arg1), canonical(expr.arg2))
    return expr
def structural_hash(expr):
    """Return the ``_hash_expr`` digest of ``canonical(expr)``.

    Expressions that canonicalize to the same form share a hash.
    """
    normalized = canonical(expr)
    digest = _hash_expr(normalized)
    return digest
def _get_product_vars(expr):
    """Collect Var names reachable through ``*`` nodes, left to right.

    Leaves other than Var (constants, sums, function calls, ...) contribute
    nothing.
    """
    names = []
    stack = [expr]
    while stack:
        node = stack.pop()
        if isinstance(node, Var):
            names.append(node.name)
        elif isinstance(node, BinOp) and node.op == "*":
            stack.append(node.right)
            stack.append(node.left)
    return names
def _get_sorted_vars_key(expr):
    """Return the product-variable names of *expr* as a sorted tuple."""
    names = _get_product_vars(expr)
    names.sort()
    return tuple(names)
def _get_const_from_term(expr):
    """Return the numeric coefficient of a sum term.

    * ``Const``: its value, or ``0.0`` when the value is ``None``.
    * ``*`` product: the product of every Const factor, including factors
      nested inside left/right ``*`` chains such as ``(2 * x) * y``.
      ``None``-valued constants are skipped.
    * anything else: ``1.0``.
    """
    if isinstance(expr, Const):
        return expr.value if expr.value is not None else 0.0
    if isinstance(expr, BinOp) and expr.op == "*":
        coeff = 1.0
        for factor in (expr.left, expr.right):
            if isinstance(factor, Const):
                if factor.value is not None:
                    coeff *= factor.value
            elif isinstance(factor, BinOp) and factor.op == "*":
                # Products are often left-nested, so constants may sit deeper
                # than a direct child; looking only at direct children lost
                # the coefficient of e.g. (2 * x) * y.
                coeff *= _get_const_from_term(factor)
        return coeff
    return 1.0
def _get_non_const_part(expr):
    """Strip constant factors from a term, keeping its symbolic part.

    * ``Const``: ``Const(1)``.
    * ``*`` product: the product with all constant factors removed, recursing
      into nested ``*`` chains. ``Const(1)`` is returned when nothing
      symbolic remains.
    * anything else (Var, sums, function calls, ...): returned unchanged.
    """
    if isinstance(expr, Const):
        return Const(1)
    if isinstance(expr, Var):
        return expr
    if isinstance(expr, BinOp) and expr.op == "*":
        parts = []
        for factor in (expr.left, expr.right):
            if isinstance(factor, Const):
                continue
            stripped = _get_non_const_part(factor)
            # An all-constant sub-product strips down to Const(1); drop it
            # rather than leaving a spurious ``1 * x`` factor behind.
            if isinstance(stripped, Const):
                continue
            parts.append(stripped)
        if len(parts) == 1:
            return parts[0]
        if len(parts) == 2:
            return BinOp(parts[0], "*", parts[1])
        return Const(1)
    return expr
def _expand_power(expr):
    """Unroll ``base ** n`` into a canonical repeated product for integer 2 <= n <= 10.

    Everything else is returned unchanged (the same object). This includes
    non-integral exponents such as ``x ** 2.5``, which previously were
    truncated to ``x ** 2``.
    """
    if not isinstance(expr, BinOp) or expr.op != "**":
        return expr
    base = expr.left
    exponent = expr.right
    if not (isinstance(exponent, Const) and isinstance(exponent.value, (int, float))):
        return expr
    value = exponent.value
    # is_integer() is also False for inf/nan, so int() below cannot fail.
    if isinstance(value, float) and not value.is_integer():
        return expr
    n = int(value)
    if not 2 <= n <= 10:
        return expr
    product = base
    for _ in range(n - 1):
        product = canonical(BinOp(product, "*", base))
    return product
def _collect_all_terms(expr):
    """List the addends of a nested ``+`` chain, left to right."""
    todo = [expr]
    terms = []
    while todo:
        node = todo.pop()
        if isinstance(node, BinOp) and node.op == "+":
            todo.append(node.right)
            todo.append(node.left)
        else:
            terms.append(node)
    return terms
def _merge_terms_list(terms):
    """Combine like terms of a sum and rebuild it as a hash-ordered ``+`` chain.

    Terms are grouped by the multiset of structural hashes of the factors in
    their non-constant part, so only terms that differ solely in their numeric
    coefficient merge (``2*x + 3*x`` -> ``5*x``). Terms without product
    variables, such as ``sin(x)``, ``-y`` and plain constants, therefore stay
    distinct; keying on product-variable names alone merged them all under
    ``()``.

    Groups whose coefficients cancel to zero are dropped. If nothing
    remains, ``Const(0)`` is returned.
    """
    groups = {}
    for term in terms:
        coeff = _get_const_from_term(term) or 0.0
        symbolic = _get_non_const_part(term)
        # Order- and association-independent key for the symbolic part.
        key = tuple(sorted(_hash_expr(f) for f in _flatten_assoc(symbolic, "*")))
        if key in groups:
            prev_coeff, prev_symbolic = groups[key]
            groups[key] = (prev_coeff + coeff, prev_symbolic)
        else:
            groups[key] = (coeff, symbolic)
    merged = []
    for coeff, symbolic in groups.values():
        if coeff == 0:
            continue
        if isinstance(symbolic, Const):
            merged.append(Const(coeff))
        elif coeff == 1.0:
            merged.append(symbolic)
        else:
            merged.append(BinOp(Const(coeff), "*", symbolic))
    if not merged:
        # Every term cancelled (or there were no terms).
        return Const(0)
    merged.sort(key=_hash_expr)
    result = merged[0]
    for term in merged[1:]:
        result = BinOp(result, "+", term)
    return result
def _combine_like_terms(expr):
    """Merge like terms when *expr* is a ``+`` node; otherwise return it as is."""
    is_sum = isinstance(expr, BinOp) and expr.op == "+"
    return _merge_terms_list(_collect_all_terms(expr)) if is_sum else expr
def _expand_product_sum(expr):
    """Apply one distribution step to a ``**`` or ``*`` node.

    * ``(a + b) ** 2`` becomes ``a*a + 2*(a*b) + b*b``. ``(a + b) ** n`` for
      integer 3 <= n <= 10 is rewritten as an n-fold product and expanded.
      Non-integral exponents are left alone instead of being truncated.
    * For ``*``, both sides are expanded first. A ``+`` operand is then
      distributed over, and a ``-`` operand is distributed into a
      canonicalized difference.

    Returns *expr* itself (the same object) when no rule applies; callers use
    the identity check to detect "no change".
    """
    if not isinstance(expr, BinOp):
        return expr
    left = expr.left
    right = expr.right
    op = expr.op
    if op == "**" and isinstance(right, Const) and isinstance(right.value, (int, float)):
        value = right.value
        # Only integral exponents are expanded; int(2.5) would silently turn
        # (a + b) ** 2.5 into (a + b) ** 2.
        is_integral = not isinstance(value, float) or value.is_integer()
        if is_integral and isinstance(left, BinOp) and left.op == "+":
            n = int(value)
            if n == 2:
                a, b = left.left, left.right
                a_sq = _expand_power(BinOp(a, "**", Const(2)))
                b_sq = _expand_power(BinOp(b, "**", Const(2)))
                two_ab = BinOp(Const(2), "*", BinOp(a, "*", b))
                result = BinOp(BinOp(a_sq, "+", two_ab), "+", b_sq)
                return _combine_like_terms(result)
            if 3 <= n <= 10:
                product = left
                for _ in range(n - 1):
                    product = BinOp(product, "*", left)
                return _combine_like_terms(expand(product))
    if op == "*":
        left_exp = _expand_product_sum(left)
        right_exp = _expand_product_sum(right)
        if isinstance(left_exp, BinOp) and left_exp.op == "+":
            term1 = BinOp(left_exp.left, "*", right_exp)
            term2 = BinOp(left_exp.right, "*", right_exp)
            return _combine_like_terms(BinOp(term1, "+", term2))
        if isinstance(right_exp, BinOp) and right_exp.op == "+":
            term1 = BinOp(left_exp, "*", right_exp.left)
            term2 = BinOp(left_exp, "*", right_exp.right)
            return _combine_like_terms(BinOp(term1, "+", term2))
        if isinstance(left_exp, BinOp) and left_exp.op == "-":
            term1 = BinOp(left_exp.left, "*", right_exp)
            term2 = BinOp(left_exp.right, "*", right_exp)
            return canonical(BinOp(term1, "-", term2))
        if isinstance(right_exp, BinOp) and right_exp.op == "-":
            term1 = BinOp(left_exp, "*", right_exp.left)
            term2 = BinOp(left_exp, "*", right_exp.right)
            return canonical(BinOp(term1, "-", term2))
    return expr
def _try_expand(expr):
    """Perform one bottom-up expansion pass over *expr*.

    BinOp children are expanded first. The rebuilt node is then distributed
    (``_expand_product_sum``), integer powers are unrolled (``_expand_power``),
    and like terms of a resulting sum are combined. UnOp, Func and Func2 nodes
    are rebuilt only when one of their children changed. Any other node is
    returned unchanged.
    """
    if isinstance(expr, BinOp):
        rebuilt = BinOp(_try_expand(expr.left), expr.op, _try_expand(expr.right))
        distributed = _expand_product_sum(rebuilt)
        if distributed is not rebuilt:
            return _combine_like_terms(distributed)
        if rebuilt.op == "**":
            unrolled = _expand_power(rebuilt)
            if unrolled is not rebuilt:
                return _combine_like_terms(unrolled)
        return _combine_like_terms(rebuilt)
    if isinstance(expr, UnOp):
        operand = _try_expand(expr.operand)
        if operand is not expr.operand:
            return UnOp(expr.op, operand)
        return expr
    if isinstance(expr, Func):
        arg = _try_expand(expr.arg)
        if arg is not expr.arg:
            return Func(arg, expr.op)
        return expr
    if isinstance(expr, Func2):
        # Expand inside two-argument calls too, consistent with Func/UnOp.
        arg1 = _try_expand(expr.arg1)
        arg2 = _try_expand(expr.arg2)
        if arg1 is not expr.arg1 or arg2 is not expr.arg2:
            return Func2(expr.name, arg1, arg2)
        return expr
    return expr
def expand(expr):
    """Apply ``_try_expand`` until the printed form stops changing.

    At most 20 passes are made; the latest result is returned even if it has
    not converged.
    """
    current = expr
    for _ in range(20):
        candidate = _try_expand(current)
        if str(candidate) == str(current):
            return current
        current = candidate
    return current
def expand_and_canonicalize(expr):
    """Expand *expr*, then alternate ``canonical`` and ``expand`` to a fixpoint.

    Iteration stops as soon as either step leaves the printed form unchanged,
    or after 20 rounds.
    """
    current = expand(expr)
    for _ in range(20):
        simplified = canonical(current)
        if str(simplified) == str(current):
            return current
        current = simplified
        re_expanded = expand(current)
        if str(re_expanded) == str(current):
            return current
        current = re_expanded
    return current
class DAGNode:
    """A node of a hash-consed expression DAG.

    Attributes:
        op: Node label, e.g. ``"Const"``, ``"Var:x"`` or ``"BinOp:+"``.
        children: List of child DAGNodes.
        structural_hash: Hash of the expression this node represents.
        _expr: The source expression, if provided.
    """

    __slots__ = ("op", "children", "structural_hash", "_expr")

    def __init__(self, op, children, structural_hash, expr=None):
        self.op, self.children = op, children
        self.structural_hash = structural_hash
        self._expr = expr
def to_dag(expr, memo=None):
    """Canonicalize *expr* and convert it to a DAG with shared identical subtrees.

    ``memo`` maps structural hashes to already-built nodes. Passing the same
    dict across calls lets separate expressions share nodes.
    """
    cache = {} if memo is None else memo
    return _to_dag_recursive(canonical(expr), cache)
def _to_dag_recursive(expr, memo):
    """Build (or fetch from *memo*) the DAGNode for *expr*.

    Children of operators in ``CommutativeOps`` and of the symmetric functions
    ``min``/``max`` are ordered by structural hash. ``atan2`` is not symmetric
    and keeps its argument order.

    Raises:
        TypeError: If *expr* is not a known node type.
    """
    h = _hash_expr(expr)
    if h in memo:
        return memo[h]
    if isinstance(expr, Const):
        op, children = "Const", []
    elif isinstance(expr, Var):
        op, children = f"Var:{expr.name}", []
    elif isinstance(expr, UnOp):
        op, children = f"UnOp:{expr.op}", [_to_dag_recursive(expr.operand, memo)]
    elif isinstance(expr, BinOp):
        op = f"BinOp:{expr.op}"
        children = [_to_dag_recursive(expr.left, memo), _to_dag_recursive(expr.right, memo)]
        if expr.op in CommutativeOps:
            children.sort(key=lambda n: n.structural_hash)
    elif isinstance(expr, Func):
        op, children = f"Func:{expr.op}", [_to_dag_recursive(expr.arg, memo)]
    elif isinstance(expr, Func2):
        op = f"Func2:{expr.name}"
        children = [_to_dag_recursive(expr.arg1, memo), _to_dag_recursive(expr.arg2, memo)]
        if expr.name in ("min", "max"):
            children.sort(key=lambda n: n.structural_hash)
    else:
        raise TypeError(f"Unknown expression type: {type(expr)}")
    node = DAGNode(op=op, children=children, structural_hash=h, expr=expr)
    memo[h] = node
    return node
# Name of the placeholder Var that stands in for every variable when an
# equation's variable-independent shape is hashed.
PLACEHOLDER_NAME: str = "_"
class VariableSignature:
    """Structural fingerprint of one occurrence of a variable in an equation.

    Equality compares the occurrence index, context hash, depth and path
    signature. ``var_uid`` is deliberately ignored, so occurrences of
    differently-identified variables in two systems can match. ``__hash__``
    ignores ``var_uid`` as well, so equal signatures always hash equally, as
    Python requires of hashable objects.
    """

    __slots__ = ("var_uid", "occurrence_idx", "context_hash", "depth", "path_signature")

    def __init__(self, var_uid, occurrence_idx, context_hash, depth, path_signature):
        self.var_uid = var_uid
        self.occurrence_idx = occurrence_idx
        self.context_hash = context_hash
        self.depth = depth
        self.path_signature = path_signature

    def __hash__(self):
        # Must use exactly the fields compared by __eq__.
        return hash((self.occurrence_idx, self.context_hash, self.depth, self.path_signature))

    def __eq__(self, other):
        if not isinstance(other, VariableSignature):
            return False
        return (self.occurrence_idx == other.occurrence_idx
                and self.context_hash == other.context_hash
                and self.depth == other.depth
                and self.path_signature == other.path_signature)

    def __repr__(self):
        return f"VarSig(uid={self.var_uid}, occ={self.occurrence_idx}, ctx={self.context_hash[:6]}, depth={self.depth})"
def _replace_var_at_position(expr, target_uid, placeholder, occurrence_counter):
    """Copy *expr*, replacing every Var whose uid is *target_uid* with *placeholder*.

    ``occurrence_counter`` is a one-element list used as a mutable counter and
    is incremented once per replacement. The second return value is the
    counter value at the first replacement made inside *expr* (left-to-right),
    or ``-1`` if nothing was replaced.
    """
    if isinstance(expr, Var):
        if expr.uid != target_uid:
            return expr, -1
        position = occurrence_counter[0]
        occurrence_counter[0] = position + 1
        return placeholder, position
    if isinstance(expr, BinOp):
        new_left, left_idx = _replace_var_at_position(expr.left, target_uid, placeholder, occurrence_counter)
        new_right, right_idx = _replace_var_at_position(expr.right, target_uid, placeholder, occurrence_counter)
        first = left_idx if left_idx != -1 else right_idx
        return BinOp(new_left, expr.op, new_right), first
    if isinstance(expr, UnOp):
        new_operand, first = _replace_var_at_position(expr.operand, target_uid, placeholder, occurrence_counter)
        return UnOp(expr.op, new_operand), first
    if isinstance(expr, Func):
        new_arg, first = _replace_var_at_position(expr.arg, target_uid, placeholder, occurrence_counter)
        return Func(new_arg, expr.op), first
    if isinstance(expr, Func2):
        new_arg1, idx1 = _replace_var_at_position(expr.arg1, target_uid, placeholder, occurrence_counter)
        new_arg2, idx2 = _replace_var_at_position(expr.arg2, target_uid, placeholder, occurrence_counter)
        first = idx1 if idx1 != -1 else idx2
        return Func2(expr.name, new_arg1, new_arg2), first
    return expr, -1
def _get_var_occurrences(expr, var_uid):
    """Return ``[0, 1, ..., k-1]``, where k is how many Vars in *expr* have uid *var_uid*."""
    found = []
    _collect_occurrences(expr, var_uid, found)
    return found
def _collect_occurrences(expr, var_uid, result):
    """Append one entry to *result* per Var in *expr* whose uid equals *var_uid*.

    Each appended value is ``len(result)`` just before the append, so an
    initially empty list ends up as the occurrence indices ``0, 1, ...``.
    """
    pending = [expr]
    while pending:
        node = pending.pop()
        if isinstance(node, Var):
            if node.uid == var_uid:
                result.append(len(result))
        elif isinstance(node, BinOp):
            pending.extend((node.right, node.left))
        elif isinstance(node, UnOp):
            pending.append(node.operand)
        elif isinstance(node, Func):
            pending.append(node.arg)
        elif isinstance(node, Func2):
            pending.extend((node.arg2, node.arg1))
def _compute_path_signature(expr, var_uid, target_occurrence, current_path, current_depth):
    """Return the root-to-leaf path to one occurrence of a variable.

    Occurrences of the Var with uid *var_uid* are numbered 0, 1, ... in the
    same left-to-right order used by ``_collect_occurrences``. The result is
    ``current_path`` extended with one label per node on the way down to
    occurrence ``target_occurrence`` (e.g. ``"BinOp:+:left"``), ending in
    ``"Var:placeholder"``. Previously ``target_occurrence`` was ignored and
    every call returned the path of the first occurrence.

    Returns ``None`` when the variable occurs fewer than
    ``target_occurrence + 1`` times. ``current_depth`` is kept for interface
    compatibility and is unused.
    """
    remaining = target_occurrence

    def walk(node, path):
        nonlocal remaining
        if isinstance(node, Var):
            if node.uid != var_uid:
                return None
            if remaining == 0:
                return path + ("Var:placeholder",)
            remaining -= 1
            return None
        if isinstance(node, BinOp):
            found = walk(node.left, path + (f"BinOp:{node.op}:left",))
            if found is not None:
                return found
            return walk(node.right, path + (f"BinOp:{node.op}:right",))
        if isinstance(node, UnOp):
            return walk(node.operand, path + (f"UnOp:{node.op}",))
        if isinstance(node, Func):
            return walk(node.arg, path + (f"Func:{node.op}",))
        if isinstance(node, Func2):
            found = walk(node.arg1, path + (f"Func2:{node.name}:left",))
            if found is not None:
                return found
            return walk(node.arg2, path + (f"Func2:{node.name}:right",))
        return None

    return walk(expr, current_path)
class VariableAlignmentEngine:
    """Find an injective mapping between the variables of two equation systems.

    Every variable occurrence is summarized by a :class:`VariableSignature`:
    occurrence index, the hash of the equation's variable-independent shape,
    and the root-to-leaf path. A ``sys1`` variable is a candidate for a
    ``sys2`` variable when their signature multisets are equal. A
    backtracking search then picks an injective assignment among the
    candidates. Variables are identified by ``Var.uid`` throughout.
    """

    def __init__(self, sys1, sys2):
        # The two equation systems (sequences of expressions) to align.
        self.sys1 = sys1
        self.sys2 = sys2
        # Structural hashes of the normalized equations of each system.
        self.sys1_hashes = set()
        self.sys2_hashes = set()
        # var uid -> list of VariableSignature, per system.
        self._var_signatures_sys1 = {}
        self._var_signatures_sys2 = {}
        # sys1 var uid -> set of compatible sys2 var uids.
        self._candidate_map = {}
        # Result: sys1 var uid -> sys2 var uid.
        self._mapping = {}
        # Expanded + canonicalized copies of the equations.
        self._norm_sys1 = []
        self._norm_sys2 = []

    def compute_mapping(self):
        """Compute and return ``{sys1_var_uid: sys2_var_uid}``.

        Returns an empty dict when no consistent assignment exists. ``sys1``
        variables with no candidate at all are left out of the mapping.
        """
        self._var_signatures_sys1.clear()
        self._var_signatures_sys2.clear()
        self._candidate_map.clear()
        self._mapping.clear()
        if not self._normalize_and_check_equivalence():
            return {}
        self._extract_signatures()
        self._build_candidate_map()
        if not self._backtrack_match():
            return {}
        if self._validate_mapping():
            return dict(self._mapping)
        return {}

    def _normalize_and_check_equivalence(self):
        """Normalize both systems and record their structural hashes.

        Always returns True. NOTE(review): the hash sets are not compared
        here, presumably because ``_hash_expr`` hashes Vars by name and names
        differ between renamed systems — confirm intent.
        """
        self._norm_sys1 = [expand_and_canonicalize(eq) for eq in self.sys1]
        self._norm_sys2 = [expand_and_canonicalize(eq) for eq in self.sys2]
        self.sys1_hashes = {structural_hash(eq) for eq in self._norm_sys1}
        self.sys2_hashes = {structural_hash(eq) for eq in self._norm_sys2}
        return True

    def _extract_signatures(self):
        """Collect per-variable signatures from the raw (unnormalized) equations."""
        for eq in self.sys1:
            self._extract_signatures_from_equation(eq, self._var_signatures_sys1)
        for eq in self.sys2:
            self._extract_signatures_from_equation(eq, self._var_signatures_sys2)

    def _extract_signatures_from_equation(self, eq, signature_dict):
        """Append a signature for every occurrence of every variable in *eq*."""
        all_vars = self._collect_vars_in_expr(eq)
        placeholder = Var(name=PLACEHOLDER_NAME)
        for var_uid in all_vars:
            occurrences = _get_var_occurrences(eq, var_uid)
            for occ_idx in occurrences:
                context_hash = self._compute_context_hash(eq, var_uid, occ_idx, placeholder)
                path_sig = _compute_path_signature(eq, var_uid, occ_idx, (), 0)
                if path_sig is None:
                    continue
                depth = len(path_sig)
                sig = VariableSignature(
                    var_uid=var_uid,
                    occurrence_idx=occ_idx,
                    context_hash=context_hash,
                    depth=depth,
                    path_signature=path_sig,
                )
                signature_dict.setdefault(var_uid, []).append(sig)

    def _compute_context_hash(self, eq, var_uid, occurrence_idx, placeholder):
        """Hash the shape of *eq* with every variable replaced by *placeholder*.

        ``var_uid`` and ``occurrence_idx`` are currently unused: all variables
        become the same placeholder, so the result depends only on *eq*.
        """
        modified_expr = _replace_all_vars_same_placeholder(eq, placeholder)
        return structural_hash(canonical(modified_expr))

    def _collect_vars_in_expr(self, expr):
        """Return the set of Var uids appearing in *expr*."""
        vars_set = set()
        self._collect_vars_recursive(expr, vars_set)
        return vars_set

    def _collect_vars_recursive(self, expr, result):
        if isinstance(expr, Var):
            result.add(expr.uid)
            return
        if isinstance(expr, BinOp):
            self._collect_vars_recursive(expr.left, result)
            self._collect_vars_recursive(expr.right, result)
            return
        if isinstance(expr, UnOp):
            self._collect_vars_recursive(expr.operand, result)
            return
        if isinstance(expr, Func):
            self._collect_vars_recursive(expr.arg, result)
            return
        if isinstance(expr, Func2):
            self._collect_vars_recursive(expr.arg1, result)
            self._collect_vars_recursive(expr.arg2, result)
            return

    def _build_candidate_map(self):
        """Map each sys1 var uid to the sys2 var uids with matching signatures."""
        for uid1, sigs1 in self._var_signatures_sys1.items():
            candidates = set()
            for uid2, sigs2 in self._var_signatures_sys2.items():
                if self._signatures_match(sigs1, sigs2):
                    candidates.add(uid2)
            if candidates:
                self._candidate_map[uid1] = candidates

    def _signatures_match(self, sigs1, sigs2):
        """Return True when the two signature lists are equal as multisets.

        Both lists are sorted on every compared field, so equal multisets line
        up element by element whatever order they were collected in. Sorting
        only by (occurrence, context) left ties in collection order and could
        report a mismatch for equal multisets.
        """
        if len(sigs1) != len(sigs2):
            return False

        def sort_key(sig):
            return (sig.occurrence_idx, sig.context_hash, sig.depth, sig.path_signature)

        ordered1 = sorted(sigs1, key=sort_key)
        ordered2 = sorted(sigs2, key=sort_key)
        return all(s1 == s2 for s1, s2 in zip(ordered1, ordered2))

    def _backtrack_match(self):
        """Search for an injective assignment; most-constrained variables go first."""
        if not self._candidate_map:
            return True
        sorted_vars = sorted(self._candidate_map.keys(), key=lambda u: (len(self._candidate_map[u]), u))
        assigned = {}
        used_sys2 = set()
        return self._recursive_match(sorted_vars, 0, assigned, used_sys2)

    def _recursive_match(self, sorted_vars, index, assigned, used_sys2):
        """Depth-first assignment of ``sorted_vars[index:]``; stores the result in ``_mapping``."""
        if index == len(sorted_vars):
            self._mapping = dict(assigned)
            return True
        var_uid = sorted_vars[index]
        candidates = self._candidate_map[var_uid]
        sorted_candidates = sorted(candidates - used_sys2)
        for candidate_uid in sorted_candidates:
            assigned[var_uid] = candidate_uid
            used_sys2.add(candidate_uid)
            if self._partial_validate(assigned):
                if self._recursive_match(sorted_vars, index + 1, assigned, used_sys2):
                    return True
            del assigned[var_uid]
            used_sys2.remove(candidate_uid)
        return False

    def _partial_validate(self, partial_mapping):
        """Return True if no two sys1 variables share a sys2 target."""
        return len(partial_mapping) == len(set(partial_mapping.values()))

    def _validate_mapping(self):
        """Return True if the final mapping is injective."""
        if len(self._mapping) != len(set(self._mapping.values())):
            return False
        return True

    def _verify_substitution(self):
        """Check that every substituted sys1 equation hashes into ``sys2_hashes``.

        NOTE(review): not called by ``compute_mapping``. Substituted Vars keep
        their sys1 names while ``_hash_expr`` hashes Vars by name, so this
        presumably fails for renamed systems — confirm before relying on it.
        """
        for eq1 in self._norm_sys1:
            substituted = self._substitute_variables(eq1, self._mapping)
            sub_hash = structural_hash(substituted)
            if sub_hash not in self.sys2_hashes:
                return False
        return True

    def _substitute_variables(self, expr, mapping):
        """Copy *expr*, replacing Var uids via *mapping* (names are kept)."""
        if isinstance(expr, Var):
            if expr.uid in mapping:
                return Var(name=expr.name, uid=mapping[expr.uid])
            return Var(name=expr.name, uid=expr.uid)
        if isinstance(expr, BinOp):
            return BinOp(
                self._substitute_variables(expr.left, mapping),
                expr.op,
                self._substitute_variables(expr.right, mapping),
            )
        if isinstance(expr, UnOp):
            return UnOp(expr.op, self._substitute_variables(expr.operand, mapping))
        if isinstance(expr, Func):
            return Func(self._substitute_variables(expr.arg, mapping), expr.op)
        if isinstance(expr, Func2):
            return Func2(
                expr.name,
                self._substitute_variables(expr.arg1, mapping),
                self._substitute_variables(expr.arg2, mapping),
            )
        return expr
def _replace_all_vars_with_placeholder(expr, target_uid, placeholder):
    """Replace all occurrences of a variable with a placeholder.

    Every Var whose uid equals *target_uid* becomes *placeholder*. Interior
    nodes are rebuilt; other leaves are reused as-is.
    """
    def swap(node):
        return _replace_all_vars_with_placeholder(node, target_uid, placeholder)

    if isinstance(expr, Var):
        return placeholder if expr.uid == target_uid else expr
    if isinstance(expr, BinOp):
        return BinOp(swap(expr.left), expr.op, swap(expr.right))
    if isinstance(expr, UnOp):
        return UnOp(expr.op, swap(expr.operand))
    if isinstance(expr, Func):
        return Func(swap(expr.arg), expr.op)
    if isinstance(expr, Func2):
        return Func2(expr.name, swap(expr.arg1), swap(expr.arg2))
    return expr
def _replace_all_vars_same_placeholder(expr, placeholder):
    """Replace ALL variables in expr with the same placeholder.

    Interior nodes are rebuilt; non-Var leaves are reused as-is.
    """
    def swap(node):
        return _replace_all_vars_same_placeholder(node, placeholder)

    if isinstance(expr, Var):
        return placeholder
    if isinstance(expr, BinOp):
        return BinOp(swap(expr.left), expr.op, swap(expr.right))
    if isinstance(expr, UnOp):
        return UnOp(expr.op, swap(expr.operand))
    if isinstance(expr, Func):
        return Func(swap(expr.arg), expr.op)
    if isinstance(expr, Func2):
        return Func2(expr.name, swap(expr.arg1), swap(expr.arg2))
    return expr
def align_variables(sys1, sys2):
    """Return the variable-uid mapping from *sys1* to *sys2* (empty dict if none found)."""
    return VariableAlignmentEngine(sys1, sys2).compute_mapping()
def test_simple_linear_system():
    """Test simple linear system with different variable names."""
    x1 = Var("x", uid=1)
    y1 = Var("y", uid=2)
    x2 = Var("a", uid=10)
    y2 = Var("b", uid=20)
    eq1_sys1 = x1 + y1
    eq2_sys1 = x1 - y1
    eq1_sys2 = x2 + y2
    eq2_sys2 = x2 - y2
    sys1 = [eq1_sys1, eq2_sys1]
    sys2 = [eq1_sys2, eq2_sys2]
    mapping = align_variables(sys1, sys2)
    # Plain strings where nothing is interpolated (f-prefix was redundant).
    print("Test simple linear system:")
    print(f" mapping: {mapping}")
    assert len(mapping) == 2, f"Expected 2 mappings, got {len(mapping)}"
    assert mapping[1] == 10, "x should map to a"
    assert mapping[2] == 20, "y should map to b"
    print(" PASSED\n")
def test_polynomial_system():
    """Test polynomial system equivalence."""
    x1 = Var("x", uid=1)
    y1 = Var("y", uid=2)
    x2 = Var("u", uid=10)
    y2 = Var("v", uid=20)
    eq1_sys1 = (x1 + y1) ** Const(2)
    eq2_sys1 = x1 ** Const(2) + Const(2) * x1 * y1 + y1 ** Const(2)
    eq1_sys2 = (x2 + y2) ** Const(2)
    eq2_sys2 = x2 ** Const(2) + Const(2) * x2 * y2 + y2 ** Const(2)
    sys1 = [eq1_sys1, eq2_sys1]
    sys2 = [eq1_sys2, eq2_sys2]
    mapping = align_variables(sys1, sys2)
    print("Test polynomial system:")
    print(f" mapping: {mapping}")
    assert len(mapping) == 2, f"Expected 2 mappings, got {len(mapping)}"
    # Checking only the size would also accept the swapped mapping x->v, y->u.
    assert mapping == {1: 10, 2: 20}, f"Expected x->u and y->v, got {mapping}"
    print(" PASSED\n")
def test_constants_ignored():
    """Test that constants are properly ignored."""
    a1 = Var("a", uid=1)
    b1 = Var("b", uid=2)
    a2 = Var("x", uid=10)
    b2 = Var("y", uid=20)
    eq1_sys1 = a1 + Const(5)
    eq2_sys1 = b1 * Const(3)
    eq1_sys2 = a2 + Const(5)
    eq2_sys2 = b2 * Const(3)
    sys1 = [eq1_sys1, eq2_sys1]
    sys2 = [eq1_sys2, eq2_sys2]
    mapping = align_variables(sys1, sys2)
    print("Test constants ignored:")
    print(f" mapping: {mapping}")
    assert len(mapping) == 2, f"Expected 2 mappings, got {len(mapping)}"
    # a only appears in "_ + 5" and b only in "_ * 3", so the pairing is forced.
    assert mapping == {1: 10, 2: 20}, f"Expected a->x and b->y, got {mapping}"
    print(" PASSED\n")
def run_all_tests():
    """Run every test function, print a summary, and return True if all passed.

    Assertion failures and unexpected exceptions both count as failures; for
    the latter the traceback is printed.
    """
    import traceback

    banner = "=" * 60
    print(banner)
    print("RUNNING VARIABLE ALIGNMENT TESTS")
    print(banner + "\n")
    suite = (
        test_simple_linear_system,
        test_polynomial_system,
        test_constants_ignored,
    )
    n_passed = 0
    n_failed = 0
    for case in suite:
        try:
            case()
        except AssertionError as exc:
            print(f" FAILED: {exc}\n")
            n_failed += 1
        except Exception as exc:
            print(f" ERROR: {exc}")
            traceback.print_exc()
            print()
            n_failed += 1
        else:
            n_passed += 1
    print(banner)
    print(f"RESULTS: {n_passed} passed, {n_failed} failed")
    print(banner)
    return n_failed == 0
if __name__ == "__main__":
    # Exit status 0 when every test passes, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)