Source code for VeraGridEngine.Utils.Symbolic.variable_alignment_engine

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0

"""
VariableAlignmentEngine - Exact 1-to-1 Variable Mapping for Equivalent Systems.

Computes a deterministic bijection between variables of two mathematically
equivalent systems of equations using structural hashing and backtracking.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Set, Tuple

from VeraGridEngine.Utils.Symbolic.symbolic import Expr, Const, Var, BinOp, UnOp, Func, Func2
from VeraGridEngine.Utils.Symbolic.compare_expressions_structure import (
    canonical,
    structural_hash,
    expand_and_canonicalize,
)


# Name given to the generic (anonymous) placeholder variable, i.e. a
# ``Var(name=PLACEHOLDER_NAME)`` that can be passed as the ``placeholder``
# argument of the replacement helpers defined in this module.
PLACEHOLDER_NAME = "_"


class VariableSignature:
    """Structural fingerprint of one variable occurrence inside an equation.

    Equality (and therefore hashing) deliberately ignores ``var_uid``: the
    engine compares signatures of variables that belong to *different*
    systems and thus carry different uids, so two occurrences are "the same"
    when their position/context fields agree. ``var_uid`` is carried along
    only to remember which variable produced the signature.

    Instances are meant to be treated as immutable; nothing enforces this,
    so do not mutate a signature after it has been placed in a set or dict.

    Attributes (see ``__slots__``):
        var_uid: uid of the variable this occurrence belongs to.
        occurrence_idx: index of the occurrence within its equation.
        context_hash: hash string describing the surrounding equation.
        depth: length of ``path_signature`` as supplied by the caller.
        path_signature: tuple of path labels from the root to the occurrence.
    """

    __slots__ = ("var_uid", "occurrence_idx", "context_hash", "depth", "path_signature")

    def __init__(
        self,
        var_uid: int,
        occurrence_idx: int,
        context_hash: str,
        depth: int,
        path_signature: Tuple[str, ...],
    ):
        self.var_uid = var_uid
        self.occurrence_idx = occurrence_idx
        self.context_hash = context_hash
        self.depth = depth
        self.path_signature = path_signature

    def _identity(self) -> Tuple[int, str, int, Tuple[str, ...]]:
        """Return the fields that define equality (everything but ``var_uid``)."""
        return (
            self.occurrence_idx,
            self.context_hash,
            self.depth,
            self.path_signature,
        )

    def __hash__(self) -> int:
        # Must hash exactly what __eq__ compares. Including var_uid here (as
        # the previous implementation did) let equal signatures hash
        # differently, which breaks set/dict membership.
        return hash(self._identity())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, VariableSignature):
            return False
        return self._identity() == other._identity()

    def __repr__(self) -> str:
        return f"VarSig(uid={self.var_uid}, occ={self.occurrence_idx}, ctx={self.context_hash[:6]}, depth={self.depth})"
def _child_nodes(expr: Expr) -> Tuple[Expr, ...]:
    """Return the direct sub-expressions of ``expr`` in left-to-right order.

    Nodes that are none of BinOp / UnOp / Func / Func2 yield an empty tuple.
    """
    if isinstance(expr, BinOp):
        return expr.left, expr.right
    if isinstance(expr, UnOp):
        return (expr.operand,)
    if isinstance(expr, Func):
        return (expr.arg,)
    if isinstance(expr, Func2):
        return expr.arg1, expr.arg2
    return ()


def _map_var_leaves(expr: Expr, leaf_fn: Any) -> Expr:
    """Rebuild ``expr`` with every ``Var`` leaf replaced by ``leaf_fn(var)``.

    ``leaf_fn`` is a callable taking the ``Var`` node and returning the
    expression to put in its place. Nodes that are neither ``Var`` nor one of
    the known operator types are returned unchanged.
    """
    if isinstance(expr, Var):
        return leaf_fn(expr)
    if isinstance(expr, BinOp):
        return BinOp(
            _map_var_leaves(expr.left, leaf_fn),
            expr.op,
            _map_var_leaves(expr.right, leaf_fn),
        )
    if isinstance(expr, UnOp):
        return UnOp(expr.op, _map_var_leaves(expr.operand, leaf_fn))
    if isinstance(expr, Func):
        return Func(_map_var_leaves(expr.arg, leaf_fn), expr.op)
    if isinstance(expr, Func2):
        return Func2(
            expr.name,
            _map_var_leaves(expr.arg1, leaf_fn),
            _map_var_leaves(expr.arg2, leaf_fn),
        )
    return expr


def _replace_var_at_position(
    expr: Expr,
    target_uid: int,
    placeholder: Var,
    occurrence_counter: List[int],
) -> Tuple[Expr, int]:
    """Swap occurrences of ``target_uid`` for ``placeholder`` while counting them.

    ``occurrence_counter`` is a one-element list used as a mutable counter; it
    is incremented once for each matching ``Var`` that is visited.

    Returns:
        A pair ``(new_expr, idx)`` where ``idx`` is the counter value seen at
        the first match found in this subtree, or ``-1`` when there is none.
    """
    if isinstance(expr, Var):
        if expr.uid != target_uid:
            return expr, -1
        hit = occurrence_counter[0]
        occurrence_counter[0] = hit + 1
        return placeholder, hit

    if isinstance(expr, BinOp):
        lhs, lhs_hit = _replace_var_at_position(expr.left, target_uid, placeholder, occurrence_counter)
        rhs, rhs_hit = _replace_var_at_position(expr.right, target_uid, placeholder, occurrence_counter)
        first_hit = rhs_hit if lhs_hit == -1 else lhs_hit
        return BinOp(lhs, expr.op, rhs), first_hit

    if isinstance(expr, UnOp):
        inner, inner_hit = _replace_var_at_position(expr.operand, target_uid, placeholder, occurrence_counter)
        return UnOp(expr.op, inner), inner_hit

    if isinstance(expr, Func):
        inner, inner_hit = _replace_var_at_position(expr.arg, target_uid, placeholder, occurrence_counter)
        return Func(inner, expr.op), inner_hit

    if isinstance(expr, Func2):
        first, first_arg_hit = _replace_var_at_position(expr.arg1, target_uid, placeholder, occurrence_counter)
        second, second_arg_hit = _replace_var_at_position(expr.arg2, target_uid, placeholder, occurrence_counter)
        first_hit = second_arg_hit if first_arg_hit == -1 else first_arg_hit
        return Func2(expr.name, first, second), first_hit

    return expr, -1


def _get_var_occurrences(expr: Expr, var_uid: int) -> List[int]:
    """Return ``[0, 1, ..., k-1]`` where ``k`` is how often ``var_uid`` occurs in ``expr``."""
    found: List[int] = []
    _collect_occurrences(expr, var_uid, found)
    return found


def _get_all_var_occurrences(expr: Expr) -> Dict[int, List[int]]:
    """Index every variable occurrence of ``expr`` in one traversal.

    Returns:
        Dict mapping each ``var_uid`` to the positions at which it occurs.
        Positions come from a single counter shared by all variables, so they
        number the ``Var`` leaves of ``expr`` in traversal order.
    """
    by_uid: Dict[int, List[int]] = {}
    shared_counter = [0]
    _collect_all_occurrences(expr, by_uid, shared_counter)
    return by_uid


def _collect_all_occurrences(expr: Expr, result: Dict[int, List[int]], counter: List[int]) -> None:
    """Fill ``result`` with the occurrence positions of all variables in ``expr``."""
    _collect_occurrences_rec(expr, result, counter)


def _collect_occurrences_rec(expr: Expr, result: Dict[int, List[int]], counter: List[int]) -> None:
    """Depth-first worker behind ``_collect_all_occurrences``.

    ``counter`` is a one-element list holding the position to assign to the
    next ``Var`` leaf encountered.
    """
    if isinstance(expr, Var):
        result.setdefault(expr.uid, []).append(counter[0])
        counter[0] += 1
        return
    for child in _child_nodes(expr):
        _collect_occurrences_rec(child, result, counter)


def _collect_occurrences(expr: Expr, var_uid: int, result: List[int]) -> None:
    """Append one running index to ``result`` per occurrence of ``var_uid`` in ``expr``."""
    if isinstance(expr, Var):
        if expr.uid == var_uid:
            result.append(len(result))
        return
    for child in _child_nodes(expr):
        _collect_occurrences(child, var_uid, result)


def _compute_all_path_signatures(
    expr: Expr,
    var_occurrences: Dict[int, List[int]],
) -> Dict[Tuple[int, int], Tuple[str, ...]]:
    """Compute the path signature of every listed variable occurrence in one pass.

    Args:
        expr: Expression to walk.
        var_occurrences: Occurrence positions per ``var_uid``.

    Returns:
        Dict mapping ``(var_uid, occurrence_idx)`` to the tuple of path labels
        leading from the root of ``expr`` to that occurrence.
    """
    paths: Dict[Tuple[int, int], Tuple[str, ...]] = {}
    # Per-variable cursor: how many occurrences of that uid were seen so far.
    seen_so_far: Dict[int, int] = dict.fromkeys(var_occurrences, 0)
    _compute_paths_rec(expr, (), 0, var_occurrences, seen_so_far, paths)
    return paths


def _compute_paths_rec(
    expr: Expr,
    current_path: Tuple[str, ...],
    current_depth: int,
    var_occurrences: Dict[int, List[int]],
    target_occurrences: Dict[int, int],
    result: Dict[Tuple[int, int], Tuple[str, ...]],
) -> None:
    """Depth-first worker behind ``_compute_all_path_signatures``.

    ``current_depth`` is passed down (incremented per level) but is not read
    here; the depth of a path is simply its length.
    """
    if isinstance(expr, Var):
        uid = expr.uid
        if uid in var_occurrences:
            positions = var_occurrences[uid]
            cursor = target_occurrences[uid]
            if cursor < len(positions):
                result[(uid, positions[cursor])] = current_path + ("Var:placeholder",)
                target_occurrences[uid] += 1
        return

    # Pair each child with the label describing the edge that leads to it.
    if isinstance(expr, BinOp):
        branches = (
            (f"BinOp:{expr.op}:left", expr.left),
            (f"BinOp:{expr.op}:right", expr.right),
        )
    elif isinstance(expr, UnOp):
        branches = ((f"UnOp:{expr.op}", expr.operand),)
    elif isinstance(expr, Func):
        branches = ((f"Func:{expr.op}", expr.arg),)
    elif isinstance(expr, Func2):
        branches = (
            (f"Func2:{expr.name}:left", expr.arg1),
            (f"Func2:{expr.name}:right", expr.arg2),
        )
    else:
        return

    for label, child in branches:
        _compute_paths_rec(
            child,
            current_path + (label,),
            current_depth + 1,
            var_occurrences,
            target_occurrences,
            result,
        )


def _replace_all_vars_with_placeholder(expr: Expr, target_uid: int, placeholder: Var) -> Expr:
    """Return a copy of ``expr`` where every ``Var`` with ``target_uid`` is ``placeholder``."""

    def _swap(node: Var) -> Expr:
        return placeholder if node.uid == target_uid else node

    return _map_var_leaves(expr, _swap)


def _replace_all_vars_same_placeholder(expr: Expr, placeholder: Var) -> Expr:
    """Return a copy of ``expr`` where every variable, whatever its uid, is ``placeholder``."""
    return _map_var_leaves(expr, lambda _node: placeholder)


def _replace_all_vars_with_name_placeholder(expr: Expr) -> Expr:
    """Return a copy of ``expr`` with each variable swapped for a name-derived placeholder.

    Each ``Var`` becomes a fresh ``Var`` named ``PH_<original name>``, so that
    structurally identical equations over differently-named variables
    (e.g. Pf vs Pt equations) remain distinguishable.
    """
    return _map_var_leaves(expr, lambda node: Var(name=f"PH_{node.name}"))
class VariableAlignmentEngine:
    """Computes exact 1-to-1 mapping between variables of equivalent systems.

    Usage: construct with the two systems and call :meth:`compute_mapping`.
    Variables are matched when the lists of occurrence signatures collected
    for them (one signature per occurrence, over all equations of a system)
    agree; a backtracking search then picks an injective assignment.
    """

    def __init__(self, sys1: List[Expr], sys2: List[Expr]):
        """Store both systems and initialise empty working state."""
        self.sys1 = sys1
        self.sys2 = sys2
        self.sys1_hashes: Set[str] = set()
        self.sys2_hashes: Set[str] = set()
        self._var_signatures_sys1: Dict[int, List[VariableSignature]] = {}
        self._var_signatures_sys2: Dict[int, List[VariableSignature]] = {}
        # sys1 variable uid -> set of sys2 variable uids it may map to
        self._candidate_map: Dict[int, Set[int]] = {}
        self._mapping: Dict[int, int] = {}
        self._norm_sys1: List[Expr] = []
        self._norm_sys2: List[Expr] = []

    def compute_mapping(self) -> Dict[int, int]:
        """Main entry point. Computes the variable mapping.

        Returns:
            A copy of the mapping from sys1 variable uids to sys2 variable
            uids, or an empty dict when no valid assignment is found. sys1
            variables without any candidate do not take part in the search
            and are therefore absent from the result.
        """
        # Reset state so the engine can be re-run.
        self._var_signatures_sys1.clear()
        self._var_signatures_sys2.clear()
        self._candidate_map.clear()
        self._mapping.clear()

        self._normalize_and_check_equivalence()
        self._extract_signatures()
        self._build_candidate_map()

        if not self._backtrack_match():
            return {}
        if self._validate_mapping():
            return dict(self._mapping)
        return {}

    def _normalize_and_check_equivalence(self) -> bool:
        """Normalize both systems and record their structural hashes.

        Stores the normalized equations and the set of their structural
        hashes per system. The hash sets are not compared here; the method
        always returns True.
        """
        self._norm_sys1 = [expand_and_canonicalize(eq) for eq in self.sys1]
        self._norm_sys2 = [expand_and_canonicalize(eq) for eq in self.sys2]
        self.sys1_hashes = {structural_hash(eq) for eq in self._norm_sys1}
        self.sys2_hashes = {structural_hash(eq) for eq in self._norm_sys2}
        return True

    def _extract_signatures(self) -> None:
        """Extract signatures for all variables in both systems."""
        for eq in self.sys1:
            self._extract_signatures_from_equation(eq, self._var_signatures_sys1)
        for eq in self.sys2:
            self._extract_signatures_from_equation(eq, self._var_signatures_sys2)

    def _extract_signatures_from_equation(
        self,
        eq: Expr,
        signature_dict: Dict[int, List[VariableSignature]],
    ) -> None:
        """Append one signature per variable occurrence of ``eq`` to ``signature_dict``.

        Args:
            eq: Equation to analyse.
            signature_dict: Accumulator keyed by variable uid; mutated in place.
        """
        all_occurrences = _get_all_var_occurrences(eq)
        if not all_occurrences:
            return

        # Use name-based placeholders so equations with differently-named
        # variables (e.g. Pf vs Pt) produce different context hashes, even if
        # they are structurally identical after blanket placeholder
        # replacement. (The blanket placeholder Var that used to be allocated
        # here was never used and has been dropped.)
        name_placeholder_expr = _replace_all_vars_with_name_placeholder(eq)
        context_hash = structural_hash(canonical(name_placeholder_expr))

        all_paths = _compute_all_path_signatures(eq, all_occurrences)

        for var_uid, occurrences in all_occurrences.items():
            for occ_idx in occurrences:
                path_sig = all_paths.get((var_uid, occ_idx))
                if path_sig is None:
                    continue
                signature_dict.setdefault(var_uid, []).append(
                    VariableSignature(
                        var_uid=var_uid,
                        occurrence_idx=occ_idx,
                        context_hash=context_hash,
                        depth=len(path_sig),
                        path_signature=path_sig,
                    )
                )

    def _compute_context_hash(
        self,
        eq: Expr,
        var_uid: int,
        occurrence_idx: int,
        placeholder: Var,
    ) -> str:
        """Hash ``eq`` after replacing *every* variable with ``placeholder``.

        ``var_uid`` and ``occurrence_idx`` are accepted for interface
        compatibility but are not used by the computation.
        """
        modified_expr = _replace_all_vars_same_placeholder(eq, placeholder)
        return structural_hash(canonical(modified_expr))

    def _collect_vars_in_expr(self, expr: Expr) -> Set[int]:
        """Collect all variable uids in an expression."""
        vars_set: Set[int] = set()
        self._collect_vars_recursive(expr, vars_set)
        return vars_set

    def _collect_vars_recursive(self, expr: Expr, result: Set[int]) -> None:
        """Recursively add the uid of every ``Var`` under ``expr`` to ``result``."""
        if isinstance(expr, Var):
            result.add(expr.uid)
            return
        if isinstance(expr, BinOp):
            self._collect_vars_recursive(expr.left, result)
            self._collect_vars_recursive(expr.right, result)
            return
        if isinstance(expr, UnOp):
            self._collect_vars_recursive(expr.operand, result)
            return
        if isinstance(expr, Func):
            self._collect_vars_recursive(expr.arg, result)
            return
        if isinstance(expr, Func2):
            self._collect_vars_recursive(expr.arg1, result)
            self._collect_vars_recursive(expr.arg2, result)
            return

    def _build_candidate_map(self) -> None:
        """Build bipartite matching candidates based on signature equality.

        A sys1 variable only gets an entry when at least one sys2 variable
        has matching signatures.
        """
        for uid1, sigs1 in self._var_signatures_sys1.items():
            candidates: Set[int] = {
                uid2
                for uid2, sigs2 in self._var_signatures_sys2.items()
                if self._signatures_match(sigs1, sigs2)
            }
            if candidates:
                self._candidate_map[uid1] = candidates

    def _signatures_match(
        self,
        sigs1: List[VariableSignature],
        sigs2: List[VariableSignature],
    ) -> bool:
        """Check if two sets of signatures represent the same variable positions.

        Both lists must have the same length and, once ordered by
        ``(occurrence_idx, context_hash)``, be pairwise equal.
        """
        if len(sigs1) != len(sigs2):
            return False

        def sort_key(sig: VariableSignature) -> Tuple[int, str]:
            return sig.occurrence_idx, sig.context_hash

        sigs1_norm = sorted(sigs1, key=sort_key)
        sigs2_norm = sorted(sigs2, key=sort_key)
        for s1, s2 in zip(sigs1_norm, sigs2_norm):
            if s1 != s2:
                return False
        return True

    def _backtrack_match(self) -> bool:
        """Perform backtracking to find valid 1-to-1 mapping.

        Returns True trivially when there are no candidates at all (the
        mapping is then left as it is).
        """
        if not self._candidate_map:
            return True

        # Most constrained variables first; uid breaks ties deterministically.
        sorted_vars = sorted(
            self._candidate_map.keys(),
            key=lambda u: (len(self._candidate_map[u]), u),
        )
        assigned: Dict[int, int] = {}
        used_sys2: Set[int] = set()
        return self._recursive_match(sorted_vars, 0, assigned, used_sys2)

    def _recursive_match(
        self,
        sorted_vars: List[int],
        index: int,
        assigned: Dict[int, int],
        used_sys2: Set[int],
    ) -> bool:
        """Recursive backtracking step over ``sorted_vars[index:]``.

        Variables are tried in the fewest-candidates-first order chosen by
        ``_backtrack_match``; candidates are tried in ascending uid order.
        On success the complete assignment is stored in ``self._mapping``.
        """
        if index == len(sorted_vars):
            self._mapping = dict(assigned)
            return True

        var_uid = sorted_vars[index]
        candidates = self._candidate_map[var_uid]
        sorted_candidates = sorted(candidates - used_sys2)

        for candidate_uid in sorted_candidates:
            assigned[var_uid] = candidate_uid
            used_sys2.add(candidate_uid)
            if self._partial_validate(assigned):
                if self._recursive_match(sorted_vars, index + 1, assigned, used_sys2):
                    return True
            # Undo the tentative assignment before trying the next candidate.
            del assigned[var_uid]
            used_sys2.remove(candidate_uid)
        return False

    def _partial_validate(self, partial_mapping: Dict[int, int]) -> bool:
        """Validate partial mapping for injectivity."""
        return len(partial_mapping) == len(set(partial_mapping.values()))

    def _validate_mapping(self) -> bool:
        """Full validation of the computed mapping (currently: injectivity only)."""
        return len(self._mapping) == len(set(self._mapping.values()))
def align_variables(sys1: List[Expr], sys2: List[Expr]) -> Dict[int, int]:
    """
    Align the variables of two equivalent systems of equations.

    Convenience wrapper that builds a ``VariableAlignmentEngine`` for the two
    systems and runs its ``compute_mapping`` method.

    Args:
        sys1: First system of equations
        sys2: Second system of equations

    Returns:
        Dict[int, int]: Mapping from sys1 variable uids to sys2 variable uids.
        Empty dict if mapping fails.
    """
    return VariableAlignmentEngine(sys1, sys2).compute_mapping()
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = [
    "VariableAlignmentEngine",
    "VariableSignature",
    "align_variables",
]