# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
"""
Module: JIT Equation Compiler & Numerical Discretization Engine
===============================================================
Abstract
--------
This module implements a symbolic-to-numeric translation engine designed to generate
optimized executable kernels for Differential-Algebraic Equation (DAE) systems.
It operates by traversing the symbolic Abstract Syntax Tree (AST) of the model and
injecting discrete time-stepping schemes (e.g., Implicit Trapezoidal, BDF2) directly
into the residual evaluation code.
Architectural Rationale
-----------------------
The decoupling of the compilation logic from the `DiffBlockSolver` enforces a strict
separation of concerns between the network topology management (Solver) and the
numerical discretization strategy (Compiler). This abstraction allows for the
interchangeability of integration methods—facilitating the switch between A-stable
and L-stable schemes—without altering the solver's core iterative algorithms.
Performance Optimization
------------------------
By fusing the physical model equations with the numerical integration formulas into
a single JIT-compiled kernel, this module eliminates the interpreter overhead and
intermediate memory allocations typical of vectorized NumPy operations. This results
in a monolithic residual evaluation function optimized for the high-frequency
computation requirements of Electromagnetic Transient (EMT) simulations.
"""
# Import libraries for utils
import numpy as np
from abc import ABC, abstractmethod
from enum import Enum
from numba import types
from typing import Dict, List, Callable, Tuple, Set, Any
import os
from pathlib import Path
import importlib.util
import hashlib
import scipy.sparse as sp
from collections import defaultdict
#RMS
from VeraGridEngine.Utils.Symbolic.symbolic import expression2numba, get_expression_vars
from VeraGridEngine.enumerations import DynamicIntegrationMethod
#Totally integrated
from VeraGridEngine.Utils.Symbolic.symbolic import Expr, BinOp, UnOp, Func, Var, Const
from VeraGridEngine.basic_structures import Vec
class GeneratedKernelCacheEntry:
    """
    Record held by the in-process kernel cache for one generated eager kernel.

    It pairs the generated Python function with the eager Numba signature
    that was supplied together with it.
    """
    __slots__ = ["_python_function", "_signature_tpe"]

    def __init__(self, python_function: Callable, signature_tpe: object) -> None:
        """
        Create a cache record.

        :param python_function: Generated Python function.
        :type python_function: Callable
        :param signature_tpe: Eager Numba signature.
        :type signature_tpe: object
        :return: None.
        :rtype: None
        """
        self._signature_tpe = signature_tpe
        self._python_function = python_function

    def get_python_function(self) -> Callable:
        """
        Give back the stored generated Python function.

        :return: Generated Python function.
        :rtype: Callable
        """
        return self._python_function

    def get_signature_tpe(self) -> object:
        """
        Give back the stored eager Numba signature.

        :return: Eager Numba signature.
        :rtype: object
        """
        return self._signature_tpe
def _get_cse_sort_key(entry: Tuple[str, str]) -> int:
    """
    Extract the integer suffix of one CSE temporary name for numeric sorting.

    :param entry: Pair ``(expression_hash, temporary_name)``.
    :return: Integer parsed from the temporary name after its first two characters.
    """
    temporary_name = entry[1]
    # Drop the two-character prefix (temporaries are generated as ``_t<N>``).
    numeric_suffix = temporary_name[2:]
    return int(numeric_suffix)
def _get_triplet_sort_key(triplet: Tuple[int, int, Expr]) -> Tuple[int, int]:
    """
    Build the deterministic ordering key of one sparse Jacobian triplet.

    :param triplet: Sparse Jacobian triplet ``(col, row, expr)``.
    :return: Sorting key ``(col, row)``; the expression is not part of the key.
    """
    column = triplet[0]
    row = triplet[1]
    return column, row
def _collect_related_wrt_uids(var: Var,
                              wrt_map: Dict[int, Tuple[int, Var]],
                              output_uids: Set[int],
                              seen_uids: Set[int]) -> None:
    """
    Walk the ``base_var`` / ``diff_var`` links of one variable and record every
    UID on the way that is a differentiation variable.

    :param var: Symbolic variable from one expression.
    :param wrt_map: Differentiation variables keyed by UID.
    :param output_uids: Output UID set updated in place.
    :param seen_uids: Visited UID set, updated in place, used to stop at already-visited variables.
    :return: None.
    """
    current_uid = var.uid

    # Already handled (also breaks base_var <-> diff_var cycles).
    if current_uid in seen_uids:
        return
    seen_uids.add(current_uid)

    if current_uid in wrt_map:
        output_uids.add(current_uid)

    # Follow the base variable first, then the differential variable.
    for linked_var in (var.base_var, var.diff_var):
        if linked_var is not None:
            _collect_related_wrt_uids(linked_var, wrt_map, output_uids, seen_uids)
def _collect_candidate_wrt_uids(eq: Expr, wrt_map: Dict[int, Tuple[int, Var]]) -> Set[int]:
    """
    Gather the UIDs of differentiation variables reachable from one equation.

    Every variable reported by ``get_expression_vars`` is expanded through
    ``_collect_related_wrt_uids``; one visited set is shared across all of them.

    :param eq: Symbolic equation.
    :param wrt_map: Differentiation variables keyed by UID.
    :return: Candidate UIDs present in the equation dependency graph.
    """
    found_uids: Set[int] = set()
    visited_uids: Set[int] = set()
    for equation_var in get_expression_vars(eq):
        _collect_related_wrt_uids(equation_var, wrt_map, found_uids, visited_uids)
    return found_uids
class EmptySparseJacobianEvaluator:
    """
    Callable that always hands back the same prebuilt empty CSC Jacobian matrix.
    """
    __slots__ = ("_matrix",)

    def __init__(self, matrix: sp.csc_matrix) -> None:
        """
        Keep a reference to the empty sparse Jacobian matrix.

        :param matrix: Prebuilt empty CSC matrix.
        :return: None.
        """
        self._matrix: sp.csc_matrix = matrix

    def __call__(self, vrs: Vec, diff: Vec, vprms: Vec, cprms: Vec, h: float) -> sp.csc_matrix:
        """
        Return the stored matrix; every argument is accepted and ignored.

        :param vrs: Runtime variables.
        :param diff: Differential variables.
        :param vprms: Variable parameters.
        :param cprms: Constant parameters.
        :param h: Integration step.
        :return: Empty CSC Jacobian matrix (the same object on every call).
        """
        # The arguments exist only to match the evaluator call signature.
        del vrs, diff, vprms, cprms, h
        return self._matrix
class EmptyVecSparseJacobianEvaluator:
    """
    Callable producing a zero-row 2D data array for vectorized Jacobians
    that have no stored entries.
    """
    __slots__ = ("_n_rows", "_n_cols")

    def __init__(self, n_rows: int, n_cols: int) -> None:
        """
        Remember the Jacobian dimensions reported by :meth:`get_sparsity`.

        :param n_rows: Number of Jacobian rows.
        :param n_cols: Number of Jacobian columns.
        :return: None.
        """
        self._n_rows: int = n_rows
        self._n_cols: int = n_cols

    def __call__(self, vrs: Vec, diff: Vec, vprms: Vec, cprms: Vec, h: float) -> np.ndarray:
        """
        Build an empty data block sized after the second axis of ``vrs``.

        :param vrs: Runtime variables; ``vrs.shape[1]`` gives the number of instances.
        :param diff: Differential variables (unused).
        :param vprms: Variable parameters (unused).
        :param cprms: Constant parameters (unused).
        :param h: Integration step (unused).
        :return: ``float64`` array of shape ``(0, n_instances)``.
        """
        instance_count: int = vrs.shape[1]
        empty_shape = (0, instance_count)
        return np.zeros(empty_shape, dtype=np.float64)

    def get_sparsity(self) -> Tuple[np.ndarray, np.ndarray, int, int]:
        """
        Describe the (empty) sparsity pattern.

        :return: Tuple ``(indices, indptr, n_rows, n_cols)`` with an empty ``int32``
                 index array and an all-zero ``int32`` pointer array of length ``n_cols + 1``.
        """
        column_count: int = self._n_cols
        row_indices = np.zeros(0, dtype=np.int32)
        column_pointers = np.zeros(column_count + 1, dtype=np.int32)
        return row_indices, column_pointers, self._n_rows, column_count
class SparseJacobianEvaluatorVecWrapper:
    """
    Callable wrapper around a vectorized sparse Jacobian filler.

    The filler writes into a 2D ``data_out`` array of shape ``(nnz, n_instances)``.
    This wrapper allocates that array, invokes the filler and returns the array.
    The stored sparsity pattern (row indices and column pointers) is available
    through :meth:`get_sparsity`.
    """
    __slots__ = ("_filler_fn", "_nnz", "_n_rows", "_n_cols",
                 "_rows", "_cols", "_indices", "_indptr")

    def __init__(self, filler_fn: Callable, nnz: int, n_rows: int, n_cols: int,
                 rows: List[int], cols: List[int],
                 indices: np.ndarray, indptr: np.ndarray) -> None:
        """
        Store the filler and the description of the sparsity pattern.

        :param filler_fn: Function called as ``filler_fn(vrs, diff, vprms, cprms, h, data_out)``.
        :param nnz: Number of rows of the data array handed to the filler.
        :param n_rows: Number of Jacobian rows.
        :param n_cols: Number of Jacobian columns.
        :param rows: Row list (stored as given).
        :param cols: Column list (stored as given).
        :param indices: Index array returned by :meth:`get_sparsity`.
        :param indptr: Pointer array returned by :meth:`get_sparsity`.
        :return: None.
        """
        self._filler_fn: Callable = filler_fn
        self._nnz: int = nnz
        self._n_rows: int = n_rows
        self._n_cols: int = n_cols
        self._rows: List[int] = rows
        self._cols: List[int] = cols
        self._indices: np.ndarray = indices
        self._indptr: np.ndarray = indptr

    def __call__(self, vrs: Vec, diff: Vec, vprms: Vec, cprms: Vec, h: float) -> np.ndarray:
        """
        Allocate a zeroed data block, let the filler populate it and return it.

        :param vrs: Runtime variables; ``vrs.shape[1]`` gives the number of instances.
        :param diff: Differential variables.
        :param vprms: Variable parameters.
        :param cprms: Constant parameters.
        :param h: Integration step.
        :return: ``float64`` array of shape ``(nnz, n_instances)``.
        """
        instance_count: int = vrs.shape[1]
        block_shape = (self._nnz, instance_count)
        data_block: np.ndarray = np.zeros(block_shape, dtype=np.float64)
        self._filler_fn(vrs, diff, vprms, cprms, h, data_block)
        return data_block

    def get_sparsity(self) -> Tuple[np.ndarray, np.ndarray, int, int]:
        """
        Expose the stored sparsity description.

        :return: Tuple ``(indices, indptr, n_rows, n_cols)``.
        """
        return self._indices, self._indptr, self._n_rows, self._n_cols
class SparseJacobianEvaluatorWrapper:
    """
    Callable sparse Jacobian built from one generated CSC filler function
    and one matrix whose ``data`` array is overwritten on every call.
    """
    __slots__ = ("_filler_fn", "_matrix")

    def __init__(self, filler_fn: Callable, matrix: sp.csc_matrix) -> None:
        """
        Keep the generated filler and the reusable CSC matrix shell.

        :param filler_fn: Generated sparse Jacobian filler function.
        :param matrix: Reusable CSC matrix shell.
        :return: None.
        """
        self._matrix: sp.csc_matrix = matrix
        self._filler_fn: Callable = filler_fn

    def __call__(self, vrs: Vec, diff: Vec, vprms: Vec, cprms: Vec, h: float) -> sp.csc_matrix:
        """
        Let the filler write into the shell's ``data`` array and return the shell.

        The same matrix object is returned on every call, so a later call
        overwrites the values seen through an earlier result.

        :param vrs: Runtime variables.
        :param diff: Differential variables.
        :param vprms: Variable parameters.
        :param cprms: Constant parameters.
        :param h: Integration step.
        :return: Filled CSC Jacobian matrix.
        """
        shell = self._matrix
        self._filler_fn(vrs, diff, vprms, cprms, h, shell.data)
        return shell
class EventParameterFunctionWrapper:
    """
    Callable adapter that runs one generated event-parameter kernel and
    returns its output as a newly allocated vector.
    """
    __slots__ = ("_raw_fn", "_equation_count")

    def __init__(self, raw_fn: Callable, equation_count: int) -> None:
        """
        Keep the generated kernel and the size of its output.

        :param raw_fn: Generated event-parameter kernel, called as ``raw_fn(event_params, glob_time, out)``.
        :param equation_count: Number of output equations.
        :return: None.
        """
        self._equation_count: int = equation_count
        self._raw_fn: Callable = raw_fn

    def __call__(self, event_params: Vec, glob_time: float) -> Vec:
        """
        Run the kernel into a fresh zero-initialised ``float64`` vector.

        :param event_params: Runtime parameter vector.
        :param glob_time: Current simulation time.
        :return: Evaluated runtime-parameter vector of length ``equation_count``.
        """
        result: Vec = np.zeros(self._equation_count, dtype=np.float64)
        self._raw_fn(event_params, glob_time, result)
        return result
class DerivativeFunctionWrapper:
    """
    Callable adapter that runs one generated lag-derivative kernel and
    returns its output as a newly allocated vector.
    """
    __slots__ = ("_raw_fn", "_diff_var_count")

    def __init__(self, raw_fn: Callable, diff_var_count: int) -> None:
        """
        Keep the generated kernel and the size of its output.

        :param raw_fn: Generated derivative kernel, called as ``raw_fn(vrs, lagvars, lagdx, h, out)``.
        :param diff_var_count: Number of derivative outputs.
        :return: None.
        """
        self._diff_var_count: int = diff_var_count
        self._raw_fn: Callable = raw_fn

    def __call__(self, vrs: Vec, lagvars: Vec, lagdx: Vec, h: float) -> Vec:
        """
        Run the kernel into a fresh zero-initialised ``float64`` vector.

        :param vrs: Runtime variable vector.
        :param lagvars: Lagged state vector.
        :param lagdx: Lagged derivative vector.
        :param h: Integration step.
        :return: Evaluated derivative vector of length ``diff_var_count``.
        """
        result: Vec = np.zeros(self._diff_var_count, dtype=np.float64)
        self._raw_fn(vrs, lagvars, lagdx, h, result)
        return result
def _build_equation_compiler_residual_cache_key(equations: List[Expr],
                                                var_map: Dict[int, int],
                                                param_map: Dict[int, int],
                                                method_name: str,
                                                use_cse: bool,
                                                offset: int,
                                                inplace: bool) -> str:
    """
    Compute the deterministic cache key of one ``EquationCompiler`` residual kernel.

    The key combines the kernel options, the sizes of both index maps and the
    structural fingerprint of every equation, in order.

    :param equations: Residual equations.
    :type equations: List[Expr]
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param use_cse: Whether CSE is enabled.
    :type use_cse: bool
    :param offset: Output offset.
    :type offset: int
    :param inplace: Whether the kernel writes in place.
    :type inplace: bool
    :return: Deterministic cache key.
    :rtype: str
    """
    equation_fingerprints: List[str] = [
        _fingerprint_codegen_expr(equation, var_map, param_map, method_name, None)
        for equation in equations
    ]
    key_fields: Tuple[str, ...] = (
        "eq-compiler-residual",
        method_name,
        str(use_cse),
        str(offset),
        str(inplace),
        str(len(var_map)),
        str(len(param_map)),
        "::".join(equation_fingerprints),
    )
    return _build_codegen_cache_key("|".join(key_fields))
def _build_equation_compiler_ad_cache_key(equations: List[Expr],
                                          var_map: Dict[int, int],
                                          param_map: Dict[int, int],
                                          method_name: str,
                                          use_cse: bool,
                                          active_indices: set | None) -> str:
    """
    Compute the deterministic cache key of one ``EquationCompiler`` AD kernel.

    The active seed indices enter the key in sorted order; ``None`` and an
    empty set both contribute an empty token list.

    :param equations: Residual equations.
    :type equations: List[Expr]
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param use_cse: Whether CSE is enabled.
    :type use_cse: bool
    :param active_indices: Optional active seed indices.
    :type active_indices: set | None
    :return: Deterministic cache key.
    :rtype: str
    """
    equation_fingerprints: List[str] = [
        _fingerprint_codegen_expr(equation, var_map, param_map, method_name, None)
        for equation in equations
    ]

    seed_tokens: List[str] = []
    if active_indices is not None:
        seed_tokens = [str(seed_index) for seed_index in sorted(active_indices)]

    key_fields: Tuple[str, ...] = (
        "eq-compiler-ad",
        method_name,
        str(use_cse),
        str(len(var_map)),
        str(len(param_map)),
        ",".join(seed_tokens),
        "::".join(equation_fingerprints),
    )
    return _build_codegen_cache_key("|".join(key_fields))
def _build_equation_compiler_matrix_cache_key(template_eq: Expr,
                                              var_map: Dict[int, int],
                                              param_map: Dict[int, int],
                                              method_name: str,
                                              col_map: Dict[str, int]) -> str:
    """
    Compute the deterministic cache key of one ``MatrixVectorizedCompiler`` kernel.

    The column map enters the key as ``name:index`` items sorted by name,
    followed by the fingerprint of the template equation.

    :param template_eq: Structural template equation.
    :type template_eq: Expr
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param col_map: Matrix-kernel column map.
    :type col_map: Dict[str, int]
    :return: Deterministic cache key.
    :rtype: str
    """
    column_items: List[str] = [
        f"{column_name}:{col_map[column_name]}" for column_name in sorted(col_map)
    ]
    template_fingerprint: str = _fingerprint_codegen_expr(
        template_eq, var_map, param_map, method_name, col_map
    )
    key_fields: Tuple[str, ...] = (
        "eq-compiler-matrix",
        method_name,
        str(len(var_map)),
        str(len(param_map)),
        ",".join(column_items),
        template_fingerprint,
    )
    return _build_codegen_cache_key("|".join(key_fields))
[docs]
class GeneratedKernelCache:
"""
In-process cache for generated eager kernels keyed by structural signature.
"""
__slots__ = ["_residual_cache", "_ad_cache", "_matrix_cache"]
def __init__(self) -> None:
"""
Build the generated-kernel cache.
:return: None.
:rtype: None
"""
self._residual_cache: Dict[str, GeneratedKernelCacheEntry] = dict()
self._ad_cache: Dict[str, GeneratedKernelCacheEntry] = dict()
self._matrix_cache: Dict[str, GeneratedKernelCacheEntry] = dict()
[docs]
def get_entry(self, cache_kind: str, cache_key: str) -> GeneratedKernelCacheEntry | None:
"""
Return one cached kernel entry.
:param cache_kind: Cache family identifier.
:type cache_kind: str
:param cache_key: Deterministic cache key.
:type cache_key: str
:return: Cached entry or ``None``.
:rtype: GeneratedKernelCacheEntry | None
"""
if cache_kind == "residual":
return self._residual_cache.get(cache_key, None)
elif cache_kind == "ad":
return self._ad_cache.get(cache_key, None)
elif cache_kind == "matrix":
return self._matrix_cache.get(cache_key, None)
else:
raise ValueError(f"Unknown generated-kernel cache kind: {cache_kind}")
[docs]
def set_entry(self, cache_kind: str, cache_key: str, entry: GeneratedKernelCacheEntry) -> None:
"""
Store one cached kernel entry.
:param cache_kind: Cache family identifier.
:type cache_kind: str
:param cache_key: Deterministic cache key.
:type cache_key: str
:param entry: Cache entry.
:type entry: GeneratedKernelCacheEntry
:return: None.
:rtype: None
"""
if cache_kind == "residual":
self._residual_cache[cache_key] = entry
elif cache_kind == "ad":
self._ad_cache[cache_key] = entry
elif cache_kind == "matrix":
self._matrix_cache[cache_key] = entry
else:
raise ValueError(f"Unknown generated-kernel cache kind: {cache_kind}")
# Module-level instance of the generated-kernel cache, created once at import time.
GENERATED_KERNEL_CACHE = GeneratedKernelCache()
def _is_zero(s: str) -> bool:
    """
    Tell whether a code fragment is exactly one of the recognised zero literals.

    :param s: Code fragment to test.
    :return: ``True`` for ``"0.0"``, ``"0"``, ``"0.00"`` or ``"(-0.0)"``; ``False`` otherwise.
    """
    zero_literals = ("0.0", "0", "0.00", "(-0.0)")
    return s in zero_literals
def _is_one(s: str) -> bool:
    """
    Tell whether a code fragment is exactly one of the recognised unit literals.

    :param s: Code fragment to test.
    :return: ``True`` for ``"1.0"``, ``"1"`` or ``"1.00"``; ``False`` otherwise.
    """
    one_literals = ("1.0", "1", "1.00")
    return s in one_literals
def _stable_digest(payload: str, length: int | None = None) -> str:
    """
    Hash a payload with SHA-256 for use in cache keys and generated-module names.

    :param payload: Deterministic payload string; it is encoded as UTF-8 before hashing.
    :type payload: str
    :param length: Optional prefix length to keep.
    :type length: int | None
    :return: Full hex digest, or its first ``length`` characters when ``length`` is given.
    :rtype: str
    """
    hasher = hashlib.sha256()
    hasher.update(payload.encode("utf-8"))
    hex_digest = hasher.hexdigest()
    if length is None:
        return hex_digest
    return hex_digest[:length]
def _compile_to_file(full_source: str, func_name: str) -> Callable:
    """
    Write generated source code to a file in ``__pycache_jit__`` and import it.

    Having the code in a real file allows Numba's ``cache=True`` to work.
    The module name is derived from ``func_name`` and a hash of the full file
    content, so identical sources map to the same file and the same module.

    :param full_source: Generated source code (without the common import header).
    :type full_source: str
    :param func_name: Name of the function to return from the generated module.
    :type func_name: str
    :return: The attribute ``func_name`` of the imported module.
    :rtype: Callable
    :raises ImportError: If no import spec or loader can be created for the generated file.
    :raises KeyError: If the module does not define ``func_name``.
    """
    import sys
    import uuid

    header = "import numpy as np\nimport math\nfrom VeraGridEngine.Utils.Symbolic.symbolic import heaviside_num as _heaviside\n\n"
    full_content = header + full_source

    repo_root = Path(__file__).resolve().parents[4]
    cache_dir = str(repo_root / "__pycache_jit__")
    os.makedirs(cache_dir, exist_ok=True)
    if cache_dir not in sys.path:
        sys.path.append(cache_dir)

    content_hash = _stable_digest(full_content, length=16)
    mod_name = f"{func_name}_{content_hash}"
    filepath = os.path.join(cache_dir, f"{mod_name}.py")

    if not os.path.exists(filepath):
        # Write to a private temporary file and rename it into place, so that a
        # concurrent reader never sees a partially written module. UTF-8 is
        # explicit because it is the encoding the importer assumes for source files.
        tmp_path = os.path.join(cache_dir, f".{mod_name}.{uuid.uuid4().hex}.tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(full_content)
            os.replace(tmp_path, filepath)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # nothing to clean up (never created or already moved)
            raise

    cached_module = sys.modules.get(mod_name)
    if cached_module is not None:
        return cached_module.__dict__[func_name]

    spec = importlib.util.spec_from_file_location(mod_name, filepath)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for generated module {mod_name!r} at {filepath!r}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered: a later call would
        # otherwise take the ``sys.modules`` shortcut and return from it.
        sys.modules.pop(mod_name, None)
        raise

    return module.__dict__[func_name]
def _build_codegen_cache_key(payload: str) -> str:
    """
    Turn a payload into the cache key used for generated eager kernels.

    The payload is prefixed with a fixed cache-version tag before hashing, so
    changing that tag changes every key.

    :param payload: Deterministic payload string.
    :type payload: str
    :return: Cache key.
    :rtype: str
    """
    cache_version: str = "jit-codegen-v1"
    versioned_payload: str = "|".join((cache_version, payload))
    return _stable_digest(versioned_payload)
def _fingerprint_codegen_var(node: Var,
                             var_map: Dict[int, int],
                             param_map: Dict[int, int],
                             method_name: str,
                             col_map: Dict[str, int] | None = None) -> str:
    """
    Build the stable code-generation fingerprint of one variable node.

    The checks are applied in this order:

    * parameter (UID in ``param_map``) -> ``P<index>``
    * mapped variable (UID in ``var_map``) -> ``S<index>`` without a column map,
      otherwise ``M<column>`` by own name, ``M<column>`` by origin-variable name
      (only for nodes with a base variable), or ``MN:<name>``
    * unmapped variable with a base variable -> ``DD[...]`` when the base variable
      itself has a base variable, else ``D[<method>|S<index>]`` / ``D[<method>|M<column>]``
    * variable named ``h`` -> ``H``
    * anything else -> ``N:<name>``

    :param node: Variable node.
    :type node: Var
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param col_map: Optional matrix-kernel column map.
    :type col_map: Dict[str, int] | None
    :return: Stable fingerprint string.
    :rtype: str
    """
    node_uid = node.uid

    if node_uid in param_map:
        return f"P{param_map[node_uid]}"

    if node_uid in var_map:
        if col_map is None:
            return f"S{var_map[node_uid]}"
        own_name: str = node.name
        if own_name in col_map:
            return f"M{col_map[own_name]}"
        if node.base_var is not None and node.origin_var.name in col_map:
            return f"M{col_map[node.origin_var.name]}"
        return f"MN:{own_name}"

    base = node.base_var
    if base is not None:
        if base.base_var is not None:
            # Derivative of a derivative: fingerprint the inner derivative recursively.
            inner_fp: str = _fingerprint_codegen_var(base, var_map, param_map, method_name, col_map)
            return f"DD[{method_name}|{inner_fp}]"
        if col_map is None:
            return f"D[{method_name}|S{var_map[base.uid]}]"
        return f"D[{method_name}|M{col_map[node.origin_var.name]}]"

    if node.name == 'h':
        return "H"
    return f"N:{node.name}"
def _fingerprint_codegen_expr(node: Expr,
var_map: Dict[int, int],
param_map: Dict[int, int],
method_name: str,
col_map: Dict[str, int] | None = None) -> str:
"""
Return a stable code-generation fingerprint for one symbolic expression.
:param node: Symbolic expression node.
:type node: Expr
:param var_map: Runtime variable index map.
:type var_map: Dict[int, int]
:param param_map: Parameter index map.
:type param_map: Dict[int, int]
:param method_name: Discretization strategy name.
:type method_name: str
:param col_map: Optional matrix-kernel column map.
:type col_map: Dict[str, int] | None
:return: Stable fingerprint string.
:rtype: str
"""
if isinstance(node, Const):
return f"C[{repr(node.value)}]"
elif isinstance(node, Var):
return _fingerprint_codegen_var(node, var_map, param_map, method_name, col_map)
elif isinstance(node, BinOp):
left_fp: str = _fingerprint_codegen_expr(node.left, var_map, param_map, method_name, col_map)
right_fp: str = _fingerprint_codegen_expr(node.right, var_map, param_map, method_name, col_map)
if node.op in ['+', '*'] and left_fp > right_fp:
left_fp, right_fp = right_fp, left_fp
else:
pass
return f"B[{node.op}|{left_fp}|{right_fp}]"
elif isinstance(node, UnOp):
operand_fp: str = _fingerprint_codegen_expr(node.operand, var_map, param_map, method_name, col_map)
return f"U[{node.op}|{operand_fp}]"
elif isinstance(node, Func):
arg_fp: str = _fingerprint_codegen_expr(node.arg, var_map, param_map, method_name, col_map)
return f"F[{node.op}|{arg_fp}]"
else:
return f"RAW[{type(node).__name__}|{str(node)}]"
def _build_residual_codegen_cache_key(equations: List[Expr],
                                      var_map: Dict[int, int],
                                      param_map: Dict[int, int],
                                      method_name: str,
                                      use_cse: bool,
                                      offset: int,
                                      n_variables: int,
                                      n_parameters: int) -> str:
    """
    Compute the cache key of one residual kernel.

    The key combines the kernel options, the declared problem sizes, the number
    of equations and the structural fingerprint of every equation, in order.

    :param equations: Residual equations.
    :type equations: List[Expr]
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param use_cse: Whether CSE is enabled.
    :type use_cse: bool
    :param offset: Output offset.
    :type offset: int
    :param n_variables: Number of runtime variables.
    :type n_variables: int
    :param n_parameters: Number of parameters.
    :type n_parameters: int
    :return: Deterministic cache key.
    :rtype: str
    """
    equation_fingerprints: List[str] = [
        _fingerprint_codegen_expr(equation, var_map, param_map, method_name, None)
        for equation in equations
    ]
    key_fields: Tuple[str, ...] = (
        "residual",
        method_name,
        str(use_cse),
        str(offset),
        str(n_variables),
        str(n_parameters),
        str(len(equations)),
        "::".join(equation_fingerprints),
    )
    return _build_codegen_cache_key("|".join(key_fields))
def _build_ad_codegen_cache_key(equations: List[Expr],
                                var_map: Dict[int, int],
                                param_map: Dict[int, int],
                                method_name: str,
                                use_cse: bool,
                                active_indices: set | None,
                                n_variables: int,
                                n_parameters: int) -> str:
    """
    Compute the cache key of one AD kernel.

    The active seed indices enter the key in sorted order; ``None`` and an
    empty set both contribute an empty token list.

    :param equations: Residual equations.
    :type equations: List[Expr]
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param use_cse: Whether CSE is enabled.
    :type use_cse: bool
    :param active_indices: Active seed indices or ``None``.
    :type active_indices: set | None
    :param n_variables: Number of runtime variables.
    :type n_variables: int
    :param n_parameters: Number of parameters.
    :type n_parameters: int
    :return: Deterministic cache key.
    :rtype: str
    """
    equation_fingerprints: List[str] = [
        _fingerprint_codegen_expr(equation, var_map, param_map, method_name, None)
        for equation in equations
    ]

    seed_tokens: List[str] = []
    if active_indices is not None:
        seed_tokens = [str(seed_index) for seed_index in sorted(active_indices)]

    key_fields: Tuple[str, ...] = (
        "ad",
        method_name,
        str(use_cse),
        str(n_variables),
        str(n_parameters),
        str(len(equations)),
        ",".join(seed_tokens),
        "::".join(equation_fingerprints),
    )
    return _build_codegen_cache_key("|".join(key_fields))
def _build_matrix_codegen_cache_key(template_eq: Expr,
                                    var_map: Dict[int, int],
                                    param_map: Dict[int, int],
                                    method_name: str,
                                    col_map: Dict[str, int],
                                    n_variables: int,
                                    n_parameters: int) -> str:
    """
    Compute the cache key of one matrix-vectorized kernel.

    The column map enters the key as ``name:index`` items sorted by name,
    followed by the fingerprint of the template equation.

    :param template_eq: Structural template equation.
    :type template_eq: Expr
    :param var_map: Runtime variable index map.
    :type var_map: Dict[int, int]
    :param param_map: Parameter index map.
    :type param_map: Dict[int, int]
    :param method_name: Discretization strategy name.
    :type method_name: str
    :param col_map: Matrix-kernel column map.
    :type col_map: Dict[str, int]
    :param n_variables: Number of runtime variables.
    :type n_variables: int
    :param n_parameters: Number of parameters.
    :type n_parameters: int
    :return: Deterministic cache key.
    :rtype: str
    """
    template_fingerprint: str = _fingerprint_codegen_expr(
        template_eq, var_map, param_map, method_name, col_map
    )
    column_items: List[str] = [
        f"{column_name}:{col_map[column_name]}" for column_name in sorted(col_map)
    ]
    key_fields: Tuple[str, ...] = (
        "matrix",
        method_name,
        str(n_variables),
        str(n_parameters),
        ",".join(column_items),
        template_fingerprint,
    )
    return _build_codegen_cache_key("|".join(key_fields))
# ==============================================================================
# 1. Discretization strategy (Numerical Methods)
# ==============================================================================
class DiscretizationMethod(ABC):
    """
    Abstract base class for discretization strategies.

    A strategy turns a state index into source-code strings that the code
    generator embeds into the generated kernels.
    """
    __slots__ = ()

    @abstractmethod
    def discretize(self, state_idx: int, h_var: str = 'h') -> str:
        """
        Return the code string standing for the time derivative of one state.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :return: Code string.
        """

    @abstractmethod
    def discretize_dot(self, state_idx: int, h_var: str = 'h', seeds_var: str = 'seeds') -> str:
        """
        Return the seed-based counterpart of :meth:`discretize` for one state.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :param seeds_var: Name of the seed array in the generated code.
        :return: Code string.
        """
class TrapezoidalMethod(DiscretizationMethod):
    """Discretization strategy emitting trapezoidal-rule expressions."""
    __slots__ = ()

    def discretize(self, state_idx: int, h_var: str = 'h') -> str:
        """
        Emit ``(2/h) * (states[i] - history[i]) - d_history[i]`` as a code string.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :return: Code string.
        """
        state_delta = f"(states[{state_idx}] - history[{state_idx}])"
        return f"((2.0/{h_var}) * {state_delta} - d_history[{state_idx}])"

    def discretize_dot(self, state_idx: int, h_var: str = 'h', seeds_var: str = 'seeds') -> str:
        """
        Emit ``(2/h) * seeds[i]`` as a code string.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :param seeds_var: Name of the seed array in the generated code.
        :return: Code string.
        """
        seed_term = f"({seeds_var}[{state_idx}])"
        return f"((2.0/{h_var}) * {seed_term})"
class BackwardEulerMethod(DiscretizationMethod):
    """Discretization strategy emitting backward-Euler expressions."""
    __slots__ = ()

    def discretize(self, state_idx: int, h_var: str = 'h') -> str:
        """
        Emit ``(states[i] - history[i]) / h`` as a code string.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :return: Code string.
        """
        state_delta = f"(states[{state_idx}] - history[{state_idx}])"
        return f"({state_delta} / {h_var})"

    def discretize_dot(self, state_idx: int, h_var: str = 'h', seeds_var: str = 'seeds') -> str:
        """
        Emit ``seeds[i] / h`` as a code string.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :param seeds_var: Name of the seed array in the generated code.
        :return: Code string.
        """
        seed_term = f"({seeds_var}[{state_idx}])"
        return f"({seed_term} / {h_var})"
class BDF2Method(DiscretizationMethod):
    """Discretization strategy emitting two-step BDF expressions."""
    __slots__ = ()

    def discretize(self, state_idx: int, h_var: str = 'h') -> str:
        """
        Emit ``(1.5*states[i] - 2.0*history[i] + 0.5*history2[i]) / h`` as a code string.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :return: Code string.
        """
        numerator = (
            f"(1.5*states[{state_idx}] - 2.0*history[{state_idx}] + 0.5*history2[{state_idx}])"
        )
        return f"({numerator} / {h_var})"

    def discretize_dot(self, state_idx: int, h_var: str = 'h', seeds_var: str = 'seeds') -> str:
        """
        Emit ``(1.5*seeds[i]) / h`` as a code string.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable in the generated code.
        :param seeds_var: Name of the seed array in the generated code.
        :return: Code string.
        """
        numerator = f"(1.5*{seeds_var}[{state_idx}])"
        return f"({numerator} / {h_var})"
class ContinuousMethod(DiscretizationMethod):
    """Strategy for continuous systems (RMS Small Signal). Does not discretize."""
    __slots__ = ()

    def discretize(self, state_idx: int, h_var: str = 'h') -> str:
        """
        Emit a direct reference to ``dx[i]``; the step-size name is not used.

        :param state_idx: Index of the state.
        :param h_var: Name of the step-size variable (ignored).
        :return: Code string ``dx[<state_idx>]``.
        """
        derivative_ref = f"dx[{state_idx}]"
        return derivative_ref

    def discretize_dot(self, state_idx: int, h_var: str = 'h', seeds_var: str = 'seeds') -> str:
        """
        Emit the constant ``0.0``; none of the arguments is used.

        :param state_idx: Index of the state (ignored).
        :param h_var: Name of the step-size variable (ignored).
        :param seeds_var: Name of the seed array (ignored).
        :return: Code string ``0.0``.
        """
        zero_literal = "0.0"
        return zero_literal
# ==============================================================================
# 2. CSE (Common Subexpression Elimination) Analyzer (large systems)
# ==============================================================================
class SubexpressionAnalyzer:
    """
    Find and catalog subexpressions that appear multiple times.

    Complexity values and canonical strings are memoized per node *object*,
    keyed by ``id(node)``. Because an ``id`` is only unique while the object is
    alive, every memoized node is also pinned in ``_memo_nodes``. This keeps the
    node alive for as long as its memo entry exists, so a different node can
    never pick up a stale entry through a reused id (``hash_expr`` is public and
    may be called with arbitrary nodes after ``analyze``).
    """
    __slots__ = ['threshold', 'expr_counts', 'expr_objects', 'expr_complexity', 'memo_complexity', 'memo_canonical',
                 'visited_traversal', '_memo_nodes']

    def __init__(self, threshold: int = 2) -> None:
        """
        Create an analyzer.

        :param threshold: Minimum number of occurrences for a subexpression to get a temporary.
        :return: None.
        """
        self.threshold: int = threshold
        self.expr_counts: Dict[str, int] = dict()
        self.expr_objects: Dict[str, Expr] = dict()
        self.expr_complexity: Dict[str, int] = dict()
        # Caches (keyed by id(node))
        self.memo_complexity: Dict[int, int] = dict()
        self.memo_canonical: Dict[int, str] = dict()
        self.visited_traversal: Set[int] = set()
        # Strong references to every memoized node, so their ids stay valid.
        self._memo_nodes: Dict[int, Expr] = dict()

    def analyze(self, equations: List[Expr]) -> Dict[str, str]:
        """
        Count repeated subexpressions and assign temporary names to them.

        All counters and caches are reset first. Candidates occurring at least
        ``threshold`` times are ranked by ``count * complexity`` (ties broken by
        hash, both descending) and named ``_t0``, ``_t1``, ... in that order.

        :param equations: Equations to analyze.
        :return: Mapping from expression hash to temporary name.
        """
        # Reset caches
        self.expr_counts.clear()
        self.expr_objects.clear()
        self.expr_complexity.clear()
        self.memo_complexity.clear()
        self.memo_canonical.clear()
        self.visited_traversal.clear()
        self._memo_nodes.clear()

        for eq in equations:
            self._count_subexpressions(eq)

        candidates: List[Tuple[int, str]] = [
            (count * self.expr_complexity[expr_hash], expr_hash)
            for expr_hash, count in self.expr_counts.items()
            if count >= self.threshold
        ]
        candidates.sort(reverse=True)

        return {expr_hash: f"_t{i}" for i, (_, expr_hash) in enumerate(candidates)}

    def _pin(self, node: Expr) -> None:
        """Keep ``node`` alive while memo entries keyed by its id exist."""
        self._memo_nodes[id(node)] = node

    def _count_subexpressions(self, node: Expr) -> None:
        """
        Recursively count occurrences of every non-trivial subexpression.

        :param node: Expression node to process.
        :return: None.
        """
        if isinstance(node, (Const, Var)):
            return

        complexity = self._calculate_complexity(node)
        if complexity > 1:
            expr_hash = self.hash_expr(node)
            self.expr_counts[expr_hash] = self.expr_counts.get(expr_hash, 0) + 1
            self.expr_objects[expr_hash] = node
            self.expr_complexity[expr_hash] = complexity

        # A shared node object is counted on every visit, but its children are
        # expanded only the first time it is seen.
        if id(node) in self.visited_traversal:
            return
        self.visited_traversal.add(id(node))

        if isinstance(node, BinOp):
            self._count_subexpressions(node.left)
            self._count_subexpressions(node.right)
        elif isinstance(node, UnOp):
            self._count_subexpressions(node.operand)
        elif isinstance(node, Func):
            self._count_subexpressions(node.arg)

    def _calculate_complexity(self, node: Expr) -> int:
        """
        Return the memoized cost estimate of one expression.

        Leaves cost 0, ``*``/``/``/``**`` cost 2, other binary operators and unary
        operators cost 1, functions cost 5, unknown nodes cost 1; child costs add up.

        :param node: Expression node.
        :return: Cost estimate.
        """
        node_id = id(node)
        if node_id in self.memo_complexity:
            return self.memo_complexity[node_id]

        res: int
        if isinstance(node, (Const, Var)):
            res = 0
        elif isinstance(node, BinOp):
            op_cost = 2 if node.op in ['*', '/', '**'] else 1
            res = op_cost + self._calculate_complexity(node.left) + \
                  self._calculate_complexity(node.right)
        elif isinstance(node, UnOp):
            res = 1 + self._calculate_complexity(node.operand)
        elif isinstance(node, Func):
            res = 5 + self._calculate_complexity(node.arg)
        else:
            res = 1

        self.memo_complexity[node_id] = res
        self._pin(node)
        return res

    def hash_expr(self, node: Expr) -> str:
        """
        Return the 12-character digest of the canonical string of one expression.

        :param node: Expression node.
        :return: Expression hash.
        """
        canonical = self._expr_to_canonical_string(node)
        return _stable_digest(canonical, length=12)

    def _expr_to_canonical_string(self, node: Expr) -> str:
        """
        Return the memoized canonical string of one expression.

        Operands of ``+`` and ``*`` are put in lexicographic order so that both
        operand orders yield the same string.

        :param node: Expression node.
        :return: Canonical string.
        """
        node_id = id(node)
        if node_id in self.memo_canonical:
            return self.memo_canonical[node_id]

        res: str
        if isinstance(node, Const):
            res = f"C{node.value}"
        elif isinstance(node, Var):
            if node.base_var is not None:
                res = f"D{node.base_var.uid}"
            else:
                res = f"V{node.uid}"
        elif isinstance(node, BinOp):
            left = self._expr_to_canonical_string(node.left)
            right = self._expr_to_canonical_string(node.right)
            if node.op in ['+', '*'] and left > right:
                left, right = right, left
            res = f"({left}{node.op}{right})"
        elif isinstance(node, UnOp):
            operand = self._expr_to_canonical_string(node.operand)
            res = f"{node.op}{operand}"
        elif isinstance(node, Func):
            res = f"{node.op}({self._expr_to_canonical_string(node.arg)})"
        else:
            res = str(node)

        self.memo_canonical[node_id] = res
        self._pin(node)
        return res
# ==============================================================================
# 3. Conversion to string based on visitor pattern (sequentiality)
# ==============================================================================
[docs]
class SymbolicToPythonVisitor:
__slots__ = ['var_map', 'param_map', 'method', 'cse_map', 'analyzer', 'in_cse_def', '_str_cache']
OP_PRECEDENCE = {'+': 10, '-': 10, '*': 20, '/': 20, '**': 30}
def __init__(self, var_map: Dict[int, int], param_map: Dict[int, int], method: DiscretizationMethod) -> None:
self.var_map = var_map
self.param_map = param_map or dict()
self.method = method
self.cse_map: Dict[str, str] = dict()
self.analyzer = None
self.in_cse_def = False
self._str_cache: Dict[Tuple[int, int], str] = dict()
def _prec(self, node: Expr) -> int:
"""Return precedence of a node. Non-operators get 'infinite' precedence."""
# BinOp nodes define precedence; everything else we treat as atomic.
if isinstance(node, BinOp):
return self.OP_PRECEDENCE.get(node.op, 0)
return 10 ** 9
def _maybe_parenthesize(self,
code: str,
child_node: Expr,
parent_op: str,
parent_prec: int,
side: str) -> str:
"""
Decide whether to wrap child expression in parentheses.
Rules:
- If child's precedence < parent's precedence => need parentheses.
- If child's precedence == parent's precedence:
* For non-associative ops (-, /): right child must be parenthesized.
* For exponent (**), which is right-associative: left child must be parenthesized when equal precedence.
* For associative ops (+, *): no parentheses needed.
"""
child_prec = self._prec(child_node)
# Lower precedence always needs parentheses
if child_prec < parent_prec:
return f"({code})"
# Equal precedence: depends on operator and side
if child_prec == parent_prec:
if parent_op in ('-', '/'):
# a - (b - c) and a / (b / c) must keep parentheses on right
if side == 'right':
return f"({code})"
if parent_op == '**':
# (a ** b) ** c must keep parentheses on left (since ** is right-associative)
if side == 'left':
return f"({code})"
# + and * are associative enough here -> no parentheses
return code
[docs]
def visit(self, node: Expr, precedence: int = 0) -> str:
"""
Dispatches node processing using explicit type matching to avoid reflection.
:param node: The symbolic expression node to visit.
:type node: Expr
:param precedence: Operator precedence level.
:type precedence: int
:return: Python code string representation of the node.
:rtype: str
"""
# 1. CSE (Common Subexpression Elimination) Check
if self.analyzer is not None and self.cse_map is not None and not self.in_cse_def:
if not isinstance(node, (Var, Const)):
h: str = self.analyzer.hash_expr(node)
if h in self.cse_map:
return self.cse_map[h]
else:
pass
else:
pass
else:
pass
# 2. Cache Check
cache_key: Tuple[int, int] = (id(node), precedence)
if cache_key in self._str_cache:
return self._str_cache[cache_key]
else:
pass
result: str = ""
match node:
case BinOp():
result = self.visit_binop(node, precedence)
case UnOp():
result = self.visit_unop(node, precedence)
case Const():
result = self.visit_const(node, precedence)
case Var():
result = self.visit_var(node, precedence)
case Func():
result = self.visit_func(node, precedence)
case _:
result = self.generic_visit(node, precedence)
# 5. Store and Return
self._str_cache[cache_key] = result
return result
[docs]
def generic_visit(self, node: Expr, _: int) -> str:
raise NotImplementedError(f"Node not supported: {type(node)}")
# def visit_binop(self, node: BinOp, prec: int) -> str:
# curr_prec = self.OP_PRECEDENCE.get(node.op, 0)
# l = self.visit(node.left, curr_prec)
# r = self.visit(node.right, curr_prec)
# code = f"{l} {node.op} {r}"
# return f"({code})" if curr_prec < prec else code
[docs]
def visit_binop(self, node: 'BinOp', prec: int) -> str:
"""
Emit Python code for a binary operation with correct parentheses.
Important:
- We do NOT just compare prec numerically in a naive way; we also handle
associativity for '-', '/', and '**' to avoid sign/structure bugs.
"""
op = node.op
curr_prec = self.OP_PRECEDENCE.get(op, 0)
# Visit children using current operator precedence as the "context"
l_code = self.visit(node.left, curr_prec)
r_code = self.visit(node.right, curr_prec)
# Parenthesize children if needed (precedence + associativity)
l_code = self._maybe_parenthesize(l_code, node.left, op, curr_prec, side='left')
r_code = self._maybe_parenthesize(r_code, node.right, op, curr_prec, side='right')
code = f"{l_code} {op} {r_code}"
# If parent context has higher precedence, wrap the whole expression
return f"({code})" if curr_prec < prec else code
[docs]
def visit_unop(self, node: UnOp, _precedence: int) -> str:
return f"{node.op}{self.visit(node.operand, 100)}"
[docs]
def visit_const(self, node: Const, _precedence: int) -> str:
return str(node.value)
[docs]
def visit_var(self, node: Var, prec: int) -> str:
if node.uid in self.var_map:
return f"states[{self.var_map[node.uid]}]"
else:
pass
if node.uid in self.param_map:
return f"params[{self.param_map[node.uid]}]"
else:
pass
if node.base_var is not None:
return self.visit_diffvar(node, prec)
else:
pass
if node.name == 'h':
return 'h'
raise ValueError(f"Variable '{node.name}' (UID: {node.uid}) Not mapped in compiler.")
[docs]
def visit_diffvar(self, node: Var, prec: int) -> str:
# If base_var is also a DiffVar, recursively discretize it
if node.base_var.base_var is not None:
# Recursively get discretized base (e.g., dx from x)
base_discretized = self.visit_diffvar(node.base_var, prec)
# Apply discretization to the discretized base using d_history
# For Backward Euler: d2x = (dx - dx_history) / h
# Use base_var's origin since base_var is itself a DiffVar
base_idx = self.var_map[node.base_var.origin_var.uid]
term = f"(({base_discretized} - d_history[{base_idx}]) / h)"
else:
# Base is a regular state variable
base_uid = node.base_var.uid
if base_uid not in self.var_map:
raise ValueError(f"Base var '{node.base_var.name}' (UID: {base_uid})"
f" for DiffVar does not exist in states.")
term = self.method.discretize(self.var_map[base_uid])
return f"({term})" if prec > 10 else term
[docs]
def visit_func(self, node: Func, _precedence: int) -> str:
arg = self.visit(node.arg, 0)
op = node.op
if op == 'heaviside':
return f"_heaviside({arg})"
if op == 'abs':
return f"np.abs({arg})"
return f"np.{op}({arg})"
# =============================================================================================
# 4. Forward-mode automatic differentiation (value + directional derivative) for large-scale algebraic systems
# =============================================================================================
[docs]
class ADVisitor(SymbolicToPythonVisitor):
__slots__ = ['seeds_var', 'active_indices', 'cse_has_dot']
def __init__(self,
var_map: Dict[int, int],
param_map: Dict[int, int],
method: DiscretizationMethod,
seeds_var: str = 'seeds',
active_indices: set | None = None) -> None:
super().__init__(var_map, param_map, method)
self.seeds_var = seeds_var
self.active_indices = active_indices
self.cse_has_dot: Set[str] = set()
self._str_cache: Dict[Tuple[int, int], Tuple[str, str]] = dict()
[docs]
def generic_visit(self, node: Expr, _: int) -> Tuple[str, str]:
raise NotImplementedError(f"Node not supported in AD: {type(node)}")
[docs]
def visit(self, node: Expr, precedence: int = 0) -> Tuple[str, str]:
"""
Dispatches node processing for Automatic Differentiation using explicit type matching.
"""
# 1. CSE Check
if self.analyzer is not None and self.cse_map is not None and not self.in_cse_def:
if not isinstance(node, (Var, Const)):
h_hash: str = self.analyzer.hash_expr(node)
if h_hash in self.cse_map:
var_name = self.cse_map[h_hash]
dot_name = f"{var_name}_d" if var_name in self.cse_has_dot else "0.0"
return var_name, dot_name
else:
pass
else:
pass
else:
pass
# 2. Cache Check (Critical for AD speed)
cache_key: Tuple[int, int] = (id(node), precedence)
if cache_key in self._str_cache:
return self._str_cache[cache_key]
else:
pass
res: Tuple[str, str] = ("", "")
match node:
case BinOp():
res = self.visit_binop(node, precedence)
case UnOp():
res = self.visit_unop(node, precedence)
case Const():
res = self.visit_const(node, precedence)
case Var():
res = self.visit_var(node, precedence)
case Func():
res = self.visit_func(node, precedence)
case _:
# Fallback for generic or unsupported nodes
res = self.generic_visit(node, precedence)
# 5. Store and Return
self._str_cache[cache_key] = res
return res
[docs]
def visit_const(self, node: Const, _: int) -> Tuple[str, str]:
return str(node.value), "0.0"
[docs]
def visit_var(self, node: Var, _precedence: int) -> Tuple[str, str]:
# 1. Standard mapping for states
if node.uid in self.var_map:
i: int = self.var_map[node.uid]
val: str = f"states[{i}]"
# Seed value for the specific column being differentiated
if self.active_indices is not None:
seed_val: str = "1.0" if i in self.active_indices else "0.0"
else:
seed_val: str = f"{self.seeds_var}[{i}]"
return val, seed_val
else:
pass
# 2. Parameter mapping (Parameters have a seed of 0.0)
if node.uid in self.param_map:
return f"params[{self.param_map[node.uid]}]", "0.0"
else:
pass
if node.base_var is not None:
return self.visit_diffvar(node, 100)
else:
pass
raise ValueError(f"Var '{node.name}' (UID: {node.uid}) Not mapped in ADVisitor.")
[docs]
def visit_diffvar(self, node: Var, prec: int) -> Tuple[str, str]:
# If base_var is also a DiffVar, recursively discretize it
if node.base_var.base_var is not None:
# Recursively get discretized base (value, dot)
base_val, base_dot = self.visit_diffvar(node.base_var, prec)
# Use base_var's origin since base_var is itself a DiffVar
base_idx = self.var_map[node.base_var.origin_var.uid]
if isinstance(self.method, BackwardEulerMethod):
# d2x = (dx - dx_history) / h
val = f"(({base_val} - d_history[{base_idx}]) / h)"
dot = "0.0" if base_dot == "0.0" else f"({base_dot}/h)"
elif isinstance(self.method, TrapezoidalMethod):
# d2x = (2/h)*(dx - dx_history) - d2x_history
val = f"((2.0/h)*({base_val} - d_history[{base_idx}]) - d2_history[{base_idx}])"
dot = "0.0" if base_dot == "0.0" else f"(2.0*{base_dot}/h)"
elif isinstance(self.method, BDF2Method):
# Not implemented for nested diff vars with BDF2
raise NotImplementedError("Nested DiffVar not supported with BDF2")
else:
raise NotImplementedError
else:
# Base is a regular state variable
base_uid = node.base_var.uid
i = self.var_map[base_uid]
if self.active_indices is not None:
seed = "1.0" if i in self.active_indices else "0.0"
else:
seed = f"{self.seeds_var}[{i}]"
if isinstance(self.method, BackwardEulerMethod):
val = f"((states[{i}] - history[{i}]) / h)"
dot = "0.0" if seed == "0.0" else ("(1.0/h)" if seed == "1.0" else f"({seed}/h)")
elif isinstance(self.method, TrapezoidalMethod):
val = f"((2.0/h)*(states[{i}] - history[{i}]) - d_history[{i}])"
dot = "0.0" if seed == "0.0" else ("(2.0/h)" if seed == "1.0" else f"(2.0*{seed}/h)")
elif isinstance(self.method, BDF2Method):
val = f"((1.5*states[{i}] - 2.0*history[{i}] + 0.5*history2[{i}]) / h)"
dot = "0.0" if seed == "0.0" else ("(1.5/h)" if seed == "1.0" else f"(1.5*{seed}/h)")
else:
raise NotImplementedError
return (f"({val})", f"({dot})") if prec > 10 else (val, dot)
[docs]
def visit_binop(self, node: BinOp, prec: int) -> Tuple[str, str]:
curr_prec = self.OP_PRECEDENCE.get(node.op, 0)
lv, ld = self.visit(node.left, curr_prec)
rv, rd = self.visit(node.right, curr_prec)
op = node.op
v = ""
if op == '+':
v = rv if _is_zero(lv) else (lv if _is_zero(rv) else f"{lv} + {rv}")
elif op == '-':
v = lv if _is_zero(rv) else (f"-{rv}" if _is_zero(lv) else f"{lv} - {rv}")
elif op == '*':
v = "0.0" if (_is_zero(lv) or _is_zero(rv)) else (
rv if _is_one(lv) else (lv if _is_one(rv) else f"{lv} * {rv}"))
elif op == '/':
v = "0.0" if _is_zero(lv) else (lv if _is_one(rv) else f"{lv} / {rv}")
elif op == '**':
v = "1.0" if _is_zero(rv) else (lv if _is_one(rv) else (
f"({lv})**({node.right.value})" if isinstance(node.right, Const) else f"({lv})**({rv})"))
d = "0.0"
if op == '+':
d = rd if _is_zero(ld) else (ld if _is_zero(rd) else f"({ld} + {rd})")
elif op == '-':
d = f"(-{rd})" if _is_zero(ld) else (ld if _is_zero(rd) else f"({ld} - {rd})")
elif op == '*':
t1 = "0.0"
if not _is_zero(ld): t1 = rv if _is_one(ld) else (ld if _is_one(rv) else f"({ld})*({rv})")
t2 = "0.0"
if not _is_zero(rd): t2 = lv if _is_one(rd) else (rd if _is_one(lv) else f"({lv})*({rd})")
d = t2 if t1 == "0.0" else (t1 if t2 == "0.0" else f"({t1} + {t2})")
elif op == '/':
if _is_zero(lv):
d = "0.0"
elif _is_zero(ld) and _is_zero(rd):
d = "0.0"
elif _is_zero(rd):
d = ld if _is_one(rv) else f"({ld}) / ({rv})"
else:
d = f"(({ld})*({rv}) - ({lv})*({rd})) / (({rv})*({rv}))"
elif op == '**':
if not (_is_zero(ld) and _is_zero(rd)):
if isinstance(node.right, Const):
val = float(node.right.value)
if val == 1.0:
d = ld
elif val == 2.0:
d = f"2.0*({lv})*({ld})"
else:
d = f"({val})*(({lv})**({val - 1}))*({ld})"
else:
d = f"(({v}) * ( ({rd})*np.log({lv}) + ({rv})*({ld})/({lv}) ))"
return (f"({v})", f"({d})") if curr_prec < prec else (v, d)
[docs]
def visit_unop(self, node: UnOp, _: int) -> Tuple[str, str]:
ov, od = self.visit(node.operand, 100)
if _is_zero(od):
s = f"{node.op}({ov})"
if node.op == '+': s = ov
return s, "0.0"
if node.op == '+': return f"(+({ov}))", f"(+({od}))"
if node.op == '-': return f"(-({ov}))", f"(-({od}))"
raise NotImplementedError
# def visit_func(self, node: Func, _: int) -> Tuple[str, str]:
# av, ad = self.visit(node.arg, 0)
# eps_str = "1e-15"
#
# if _is_zero(ad):
# if node.op == 'heaviside': return f"_heaviside({av})", "0.0"
# if node.op == 'abs': return f"np.abs({av})", "0.0"
# return f"np.{node.op}({av})", "0.0"
#
# if node.op == 'sin': return f"np.sin({av})", f"(np.cos({av})*({ad}))"
# if node.op == 'cos': return f"np.cos({av})", f"(-np.sin({av})*({ad}))"
#
# if node.op == 'log':
# return f"np.log({av})", f"({ad})/({av} + {eps_str})"
# if node.op == 'sqrt':
# return f"np.sqrt({av})", f"0.5*({ad})/(np.sqrt({av}) + {eps_str})"
#
# raise NotImplementedError(f"Function {node.op} not supported in AD")
[docs]
def visit_func(self, node: Func, _: int) -> Tuple[str, str]:
av, ad = self.visit(node.arg, 0)
eps_str = "1e-15"
# Support non-smooth functions with zero AD derivative
if node.op == 'heaviside':
return f"_heaviside({av})", "0.0"
if node.op == 'abs':
if _is_zero(ad):
return f"np.abs({av})", "0.0"
return f"np.abs({av})", f"(np.sign({av})*({ad}))"
if _is_zero(ad):
return f"np.{node.op}({av})", "0.0"
if node.op == 'sin':
return f"np.sin({av})", f"(np.cos({av})*({ad}))"
if node.op == 'cos':
return f"np.cos({av})", f"(-np.sin({av})*({ad}))"
if node.op == 'exp':
return f"np.exp({av})", f"(np.exp({av})*({ad}))"
if node.op == 'log':
return f"np.log({av})", f"({ad})/({av} + {eps_str})"
if node.op == 'sqrt':
return f"np.sqrt({av})", f"0.5*({ad})/(np.sqrt({av}) + {eps_str})"
raise NotImplementedError(f"Function {node.op} not supported in AD")
# ==============================================================================
# 5. Compiler (sym-num)(discretized)
# ==============================================================================
[docs]
class EquationCompiler:
    """
    Main interface for compiling symbolic equations into executable functions.

    The compiler owns one ``SymbolicToPythonVisitor`` that is reused by every
    :meth:`compile` call; its per-compilation state is cleared when a call ends,
    whether it succeeds or raises.
    """
    __slots__ = ['variables_objs', 'parameters_objs', 'strategy', 'var_map', 'param_map', 'visitor']

    # Integration method selector -> discretization strategy instance.
    METHODS = {
        DynamicIntegrationMethod.DaeTrapezoidal: TrapezoidalMethod(),
        DynamicIntegrationMethod.DaeBackEuler: BackwardEulerMethod(),
        DynamicIntegrationMethod.DaeBDF2: BDF2Method(),
        DynamicIntegrationMethod.DaeContinuous: ContinuousMethod()
    }

    def __init__(self,
                 variables: List[Var],
                 parameters: List[Var] | None = None,
                 method: DynamicIntegrationMethod = DynamicIntegrationMethod.DaeTrapezoidal) -> None:
        """
        :param variables: State variables; their list position becomes the ``states`` index.
        :type variables: List[Var]
        :param parameters: Parameters; their list position becomes the ``params`` index.
        :type parameters: List[Var] | None
        :param method: Integration method; must be a key of ``METHODS``.
        :type method: DynamicIntegrationMethod
        :raises ValueError: If ``method`` is not a key of ``METHODS``.
        """
        self.variables_objs = variables
        self.parameters_objs = parameters if parameters is not None else list()
        if method not in self.METHODS:
            raise ValueError(f"Method '{method}' Unknown. Options: {list(self.METHODS.keys())}")
        self.strategy = self.METHODS[method]
        self.var_map = {v.uid: i for i, v in enumerate(self.variables_objs)}
        self.param_map = {p.uid: i for i, p in enumerate(self.parameters_objs)}
        self.visitor = SymbolicToPythonVisitor(self.var_map, self.param_map, self.strategy)

    def _clear_visitor_state(self) -> None:
        """Drop every piece of per-compilation state held by the shared visitor."""
        self.visitor.cse_map = dict()
        self.visitor.analyzer = None
        self.visitor.in_cse_def = False
        self.visitor._str_cache = dict()

    def compile(self, equations: List[Expr], func_name: str = "step_fn", use_cse: bool = True,
                offset: int = 0, inplace: bool = False) -> Callable:
        """
        Compile residual equations into a Python function.

        The generated function is named after the cache key, not after
        ``func_name`` (the argument is accepted but not used here). Results are
        stored in ``GENERATED_KERNEL_CACHE`` under the ``"residual"`` family.

        :param equations: Residual expressions, one per output entry.
        :type equations: List[Expr]
        :param func_name: Unused in this method.
        :type func_name: str
        :param use_cse: Emit common subexpressions as temporaries before the residuals.
        :type use_cse: bool
        :param offset: Added to the output index when ``inplace`` is true.
        :type offset: int
        :param inplace: Write into a caller-supplied ``residuals`` argument instead of
            allocating and returning a new array.
        :type inplace: bool
        :return: The generated Python function.
        :rtype: Callable
        """
        cache_key: str = _build_equation_compiler_residual_cache_key(
            equations,
            self.var_map,
            self.param_map,
            type(self.strategy).__name__,
            use_cse,
            offset,
            inplace,
        )
        cache_entry: GeneratedKernelCacheEntry | None = GENERATED_KERNEL_CACHE.get_entry("residual", cache_key)
        if cache_entry is not None:
            return cache_entry.get_python_function()

        canonical_func_name: str = f"eq_step_{cache_key[:16]}"
        lines: List[str] = list()

        # Header depends on mode
        if inplace:
            lines.append(f"def {canonical_func_name}(states, params, history, d_history, h, residuals, history2=None):")
        else:
            lines.append(f"def {canonical_func_name}(states, params, history, d_history, h, history2=None):")

        # The visitor is shared between calls. If code generation raises, its CSE map,
        # analyzer and id()-keyed string cache must not leak into the next compilation.
        try:
            # --- CSE Block ---
            if use_cse:
                analyzer = SubexpressionAnalyzer()
                cse_map = analyzer.analyze(equations)
                self.visitor.analyzer = analyzer
                self.visitor.cse_map = cse_map
                self.visitor._str_cache = dict()
                if cse_map:
                    sorted_cse = sorted(cse_map.items(), key=_get_cse_sort_key)
                    for expr_hash, var_name in sorted_cse:
                        expr_obj = analyzer.expr_objects[expr_hash]
                        # Definitions are written fully expanded (no CSE substitution).
                        self.visitor.in_cse_def = True
                        expr_code = self.visitor.visit(expr_obj)
                        self.visitor.in_cse_def = False
                        lines.append(f" {var_name} = {expr_code}")

            # --- Body ---
            if not inplace:
                lines.append(f" residuals = np.zeros({len(equations)}, dtype=np.float64)")
            for i, eq in enumerate(equations):
                code = self.visitor.visit(eq)
                # Apply offset if inplace
                target_idx = i + offset if inplace else i
                lines.append(f" residuals[{target_idx}] = {code}")
            if not inplace:
                lines.append(" return residuals")
        finally:
            # Clean visitor
            self._clear_visitor_state()

        # An in-place kernel without equations would consist of the header only,
        # which is not valid Python; give it an explicit empty body.
        if len(lines) == 1:
            lines.append(" pass")

        full_source = "\n".join(lines)
        py_func: Callable = _compile_to_file(full_source, canonical_func_name)
        GENERATED_KERNEL_CACHE.set_entry("residual", cache_key, GeneratedKernelCacheEntry(py_func, None))
        return py_func

    def compile_ad_kernel(self, equations: List[Expr], func_name: str = "ad_step", use_cse: bool = True,
                          active_indices: set | None = None) -> Callable:
        """
        Compile a forward-mode AD kernel returning the directional derivative of each equation.

        The generated function allocates ``jvp`` with one entry per equation and
        assigns only the entries whose derivative is not the ``"0.0"`` marker.
        Results are stored in ``GENERATED_KERNEL_CACHE`` under the ``"ad"`` family.

        :param equations: Residual expressions to differentiate.
        :type equations: List[Expr]
        :param func_name: Unused in this method.
        :type func_name: str
        :param use_cse: Emit common subexpressions (value and derivative) as temporaries.
        :type use_cse: bool
        :param active_indices: Forwarded to ``ADVisitor``.
        :type active_indices: set | None
        :return: The generated Python function.
        :rtype: Callable
        """
        cache_key: str = _build_equation_compiler_ad_cache_key(
            equations,
            self.var_map,
            self.param_map,
            type(self.strategy).__name__,
            use_cse,
            active_indices,
        )
        cache_entry: GeneratedKernelCacheEntry | None = GENERATED_KERNEL_CACHE.get_entry("ad", cache_key)
        if cache_entry is not None:
            return cache_entry.get_python_function()

        canonical_func_name: str = f"eq_ad_{cache_key[:16]}"
        # The AD visitor is local to this call, so it needs no cleanup afterwards.
        adv = ADVisitor(self.var_map, self.param_map, self.strategy, seeds_var='seeds', active_indices=active_indices)

        lines: List[str] = list()
        lines.append(f"def {canonical_func_name}(states, seeds, params, history, d_history, h, history2=None):")

        if use_cse:
            analyzer = SubexpressionAnalyzer()
            cse_map = analyzer.analyze(equations)
            adv.analyzer = analyzer
            adv.cse_map = cse_map
            adv._str_cache = dict()
            if cse_map:
                sorted_cse = sorted(cse_map.items(), key=_get_cse_sort_key)
                for expr_hash, var_name in sorted_cse:
                    expr_obj = analyzer.expr_objects[expr_hash]
                    adv.in_cse_def = True
                    expr_val, expr_dot = adv.visit(expr_obj)
                    adv.in_cse_def = False
                    lines.append(f" {var_name} = {expr_val}")
                    # Only temporaries with a non-zero derivative get a "<name>_d" companion.
                    if expr_dot != "0.0":
                        lines.append(f" {var_name}_d = {expr_dot}")
                        adv.cse_has_dot.add(var_name)

        n: int = len(equations)
        lines.append(f" jvp = np.zeros({n}, dtype=np.float64)")
        for i, eq in enumerate(equations):
            _val, dot = adv.visit(eq)
            if dot != "0.0":
                lines.append(f" jvp[{i}] = {dot}")
        lines.append(" return jvp")

        full_source: str = "\n".join(lines)
        py_func: Callable = _compile_to_file(full_source, canonical_func_name)
        GENERATED_KERNEL_CACHE.set_entry("ad", cache_key, GeneratedKernelCacheEntry(py_func, None))
        return py_func
# ==============================================================================
# 6. Matrix Vectorized Compiler (HPC Optimized - No Dictionaries)
# ==============================================================================
[docs]
class MatrixVectorizedVisitor(SymbolicToPythonVisitor):
    """
    HPC-optimized visitor. Instead of performing dictionary lookups (slow),
    it accesses state variable indices directly from a mapping matrix (fast).

    Generates code such as: states[indices[:, 5]]
    """
    __slots__ = ['col_map']

    def __init__(self,
                 var_map: Dict[int, int],
                 param_map: Dict[int, int],
                 method: DiscretizationMethod,
                 col_map: Dict[str, int]) -> None:
        """
        :param var_map: Variable ``uid`` -> state index (used here only for membership tests).
        :type var_map: Dict[int, int]
        :param param_map: Parameter ``uid`` -> index into ``params``.
        :type param_map: Dict[int, int]
        :param method: Discretization strategy used for derivative variables.
        :type method: DiscretizationMethod
        :param col_map: Variable name -> column of the ``indices`` matrix.
        :type col_map: Dict[str, int]
        """
        super().__init__(var_map, param_map, method)
        self.col_map = col_map  # Maps variable_name -> column_index (int)

    def visit_var(self, node: Var, _precedence: int) -> str:
        """
        Resolve a variable to its vectorized access expression.

        Lookup order: parameter, state (gathered through ``indices``), derivative
        variable, and finally the symbol named ``h``, emitted verbatim as the base
        visitor does.

        :raises ValueError: If the variable cannot be resolved.
        """
        if node.uid in self.param_map:
            return f"params[{self.param_map[node.uid]}]"

        if node.uid in self.var_map:
            col_idx = self.col_map[node.name]
            return f"states[indices[:, {col_idx}]]"

        # EMT Support: If the node is a derivative, call visit_diffvar
        if node.base_var is not None:
            return self.visit_diffvar(node, 100)

        # The base visitor resolves the step symbol 'h'; keep that contract here,
        # since the generated kernel receives `h` as an argument.
        if node.name == 'h':
            return 'h'

        raise ValueError(f"Var '{node.name}' (UID: {node.uid}) Not mapped in matrix visitor.")

    def visit_diffvar(self, node: Var, prec: int) -> str:
        """
        Replace a derivative variable by its vectorized discretization.

        :param node: Variable whose ``base_var`` is set.
        :type node: Var
        :param prec: Precedence of the surrounding context; values above 10 wrap the result.
        :type prec: int
        :return: Source text approximating the derivative.
        :rtype: str
        :raises NotImplementedError: For methods without a formula in the selected branch.
        """
        # If base_var is also a DiffVar, recursively discretize it
        if node.base_var.base_var is not None:
            # Use base_var's origin for history lookup since base_var is a DiffVar
            base_origin_name = node.base_var.origin_var.name
            col_idx = self.col_map[base_origin_name]
            idx_access = f"indices[:, {col_idx}]"

            # Recursively get discretized base
            base_discretized = self.visit_diffvar(node.base_var, prec)

            # Apply discretization using d_history
            if isinstance(self.method, BackwardEulerMethod):
                term = f"(({base_discretized} - d_history[{idx_access}]) / h)"
            elif isinstance(self.method, TrapezoidalMethod):
                term = f"((2.0/h)*({base_discretized} - d_history[{idx_access}]) - d2_history[{idx_access}])"
            else:
                raise NotImplementedError(f"Nested DiffVar not supported with {type(self.method)}")
        else:
            # Base is a regular state variable - use standard discretization
            # Get origin variable name for column mapping
            origin_name = node.origin_var.name
            col_idx = self.col_map[origin_name]
            idx_access = f"indices[:, {col_idx}]"

            if isinstance(self.method, TrapezoidalMethod):
                term = f"((2.0/h) * (states[{idx_access}] - history[{idx_access}]) - d_history[{idx_access}])"
            elif isinstance(self.method, BackwardEulerMethod):
                term = f"((states[{idx_access}] - history[{idx_access}]) / h)"
            elif isinstance(self.method, BDF2Method):
                term = f"((1.5*states[{idx_access}] - 2.0*history[{idx_access}] + 0.5*history2[{idx_access}]) / h)"
            else:
                raise NotImplementedError(f"Method {type(self.method)} not supported in matrix vectorization.")

        return f"({term})" if prec > 10 else term
[docs]
class MatrixVectorizedCompiler(EquationCompiler):
    """
    Equation compiler built on ``MatrixVectorizedVisitor``.

    The kernels it generates take an extra ``indices`` argument that the emitted
    code slices column-wise (``indices[:, k]``) to gather a whole batch at once.
    """
    __slots__ = []

    def compile_matrix_kernel(self, template_eq: Expr, func_name: str, template_vars: List[Var]) -> Callable:
        """
        Compile one template equation into a vectorized kernel.

        The generated function is named after the cache key, not after
        ``func_name`` (the argument is accepted but not used here). Results are
        stored in ``GENERATED_KERNEL_CACHE`` under the ``"matrix"`` family.

        :param template_eq: Equation to vectorize.
        :type template_eq: Expr
        :param func_name: Unused in this method.
        :type func_name: str
        :param template_vars: Variables in column order; a derivative variable is
            registered under the name of its ``base_var``.
        :type template_vars: List[Var]
        :return: The generated Python function.
        :rtype: Callable
        """
        # Column position of every template variable inside the index matrix.
        col_map: Dict[str, int] = {
            (var.name if var.base_var is None else var.base_var.name): position
            for position, var in enumerate(template_vars)
        }

        cache_key: str = _build_equation_compiler_matrix_cache_key(
            template_eq,
            self.var_map,
            self.param_map,
            type(self.strategy).__name__,
            col_map,
        )
        cached: GeneratedKernelCacheEntry | None = GENERATED_KERNEL_CACHE.get_entry("matrix", cache_key)
        if cached is not None:
            return cached.get_python_function()

        kernel_name: str = f"eq_matrix_{cache_key[:16]}"

        # Right-hand side text produced by the matrix visitor.
        emitter = MatrixVectorizedVisitor(self.var_map, self.param_map, self.strategy, col_map)
        rhs_code = emitter.visit(template_eq)

        # Assemble the kernel source in one go.
        full_source = "\n".join((
            f"def {kernel_name}(states, params, history, d_history, h, indices, history2=None):",
            " # Vectorized Matrix Kernel (No Dicts)",
            f" residuals = {rhs_code}",
            " return residuals",
        ))

        # Write to file and import, then remember the result for later calls.
        py_func: Callable = _compile_to_file(full_source, kernel_name)
        GENERATED_KERNEL_CACHE.set_entry("matrix", cache_key, GeneratedKernelCacheEntry(py_func, None))
        return py_func
[docs]
class EagerKernelKind(Enum):
    """
    Kernel families understood by the eager compiler.

    Each family has its own generated source layout and its own Numba
    signature, so the compiler must be told explicitly which one to build.
    """
    __slots__ = ()

    # Residual evaluation kernel.
    Residual = "residual"
    # Forward-mode automatic differentiation kernel.
    AutomaticDifferentiation = "automatic_differentiation"
    # Vectorized matrix kernel.
    MatrixVectorized = "matrix_vectorized"
def _get_cse_sort_index(cse_item: Tuple[str, str]) -> int:
"""
Return the numeric order of a generated CSE temporary.
:param cse_item: Pair ``(expr_hash, temp_name)`` produced by the analyzer.
:type cse_item: Tuple[str, str]
:return: Numeric suffix of the temporary variable.
:rtype: int
"""
temp_name: str = cse_item[1]
numeric_part: str = temp_name[2:]
return int(numeric_part)
def _reset_symbolic_codegen_state(visitor: "SymbolicToPythonVisitor | ADVisitor") -> None:
    """
    Put a visitor back into its pristine code-generation state.

    Clears the CSE map, detaches the analyzer, leaves CSE-definition mode and
    empties the string cache. For an ``ADVisitor`` the set of temporaries that
    own a derivative is emptied as well.

    :param visitor: Symbolic or AD visitor instance.
    :type visitor: SymbolicToPythonVisitor | ADVisitor
    :return: None
    :rtype: None
    """
    fresh_state = (
        ("cse_map", dict()),
        ("analyzer", None),
        ("in_cse_def", False),
        ("_str_cache", dict()),
    )
    for attr_name, attr_value in fresh_state:
        setattr(visitor, attr_name, attr_value)

    if not isinstance(visitor, ADVisitor):
        return
    visitor.cse_has_dot = set()
[docs]
class EagerEquationCompiler(EquationCompiler):
"""
Strict eager compiler that emits in-place kernels with explicit signatures.
The generated functions do not allocate or return temporary arrays. Instead,
they receive a caller-provided ``data_out`` vector and write their results in
place. This removes Numba's type inference cost at first call and reduces the
runtime allocation pressure inside Newton iterations.
"""
__slots__ = []
[docs]
def generate_signature(
self,
kernel_tpe: EagerKernelKind,
n_variables: int,
n_parameters: int,
nnz: int,
with_history2: bool = True,
) -> object:
"""
Build the eager Numba signature for a generated kernel.
:param kernel_tpe: Kernel family to be compiled eagerly.
:type kernel_tpe: EagerKernelKind
:param n_variables: Number of runtime variables of the DAE system.
:type n_variables: int
:param n_parameters: Number of parameters visible to the kernel.
:type n_parameters: int
:param nnz: Number of output values written by the kernel.
:type nnz: int
:param with_history2: Whether the generated ABI includes ``history2``.
:type with_history2: bool
:return: Numba eager signature object.
:rtype: object
"""
float64_vector_tpe: object = types.Array(types.float64, 1, "C")
int32_matrix_tpe: object = types.Array(types.int32, 2, "C", readonly=True)
if n_variables <= 0:
raise ValueError("n_variables must be greater than zero")
else:
pass
if n_parameters < 0:
raise ValueError("n_parameters must be greater than or equal to zero")
else:
pass
if nnz < 0:
raise ValueError("nnz must be greater than or equal to zero")
else:
pass
if with_history2:
pass
else:
raise ValueError(
"EagerEquationCompiler requires an explicit history2 array to keep a strict ABI."
)
signature_tpe: object
if kernel_tpe == EagerKernelKind.Residual:
signature_tpe = types.void(
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
types.float64,
float64_vector_tpe,
float64_vector_tpe,
)
elif kernel_tpe == EagerKernelKind.AutomaticDifferentiation:
signature_tpe = types.void(
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
types.float64,
float64_vector_tpe,
float64_vector_tpe,
)
elif kernel_tpe == EagerKernelKind.MatrixVectorized:
signature_tpe = types.void(
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
float64_vector_tpe,
types.float64,
int32_matrix_tpe,
float64_vector_tpe,
float64_vector_tpe,
)
else:
raise ValueError(f"Unsupported eager kernel kind: {kernel_tpe}")
return signature_tpe
[docs]
def compile(
        self,
        equations: List[Expr],
        func_name: str = "step_fn",
        use_cse: bool = True,
        offset: int = 0,
        inplace: bool = True,
) -> Tuple[Callable, object]:
    """
    Compile residual equations into an in-place eager kernel.

    The generated function is named after the cache key, not after
    ``func_name`` (the argument is accepted but not used here). Equation ``i``
    is written to ``data_out[offset + i]``. An empty equation list yields a
    kernel with an empty body.

    :param equations: Residual equations to be emitted.
    :type equations: List[Expr]
    :param func_name: Unused in this method.
    :type func_name: str
    :param use_cse: Whether to emit common subexpressions.
    :type use_cse: bool
    :param offset: Output offset inside ``data_out``.
    :type offset: int
    :param inplace: Compatibility argument that must stay ``True``.
    :type inplace: bool
    :return: Pair ``(python_function, eager_signature)``.
    :rtype: Tuple[Callable, object]
    :raises ValueError: If ``inplace`` is false.
    """
    lines: List[str] = list()
    py_func: Callable
    signature_tpe: object
    full_source: str
    cache_key: str
    cache_entry: GeneratedKernelCacheEntry | None
    canonical_func_name: str
    expr_code: str

    if not inplace:
        raise ValueError("EagerEquationCompiler only supports inplace residual kernels")

    cache_key = _build_residual_codegen_cache_key(
        equations,
        self.var_map,
        self.param_map,
        type(self.strategy).__name__,
        use_cse,
        offset,
        len(self.variables_objs),
        len(self.parameters_objs),
    )
    cache_entry = GENERATED_KERNEL_CACHE.get_entry("residual", cache_key)
    if cache_entry is not None:
        return cache_entry.get_python_function(), cache_entry.get_signature_tpe()

    # The visitor is shared between calls: start from a clean state.
    _reset_symbolic_codegen_state(self.visitor)
    canonical_func_name = f"eager_residual_{cache_key[:16]}"

    # The eager ABI receives an explicit output vector and writes residuals in place.
    lines.append(f"def {canonical_func_name}(states, params, history, d_history, h, data_out, history2):")

    # Common subexpressions are emitted first so the body performs direct writes only once.
    if use_cse:
        analyzer: SubexpressionAnalyzer = SubexpressionAnalyzer()
        cse_map: Dict[str, str] = analyzer.analyze(equations)
        sorted_cse: List[Tuple[str, str]] = sorted(cse_map.items(), key=_get_cse_sort_index)
        self.visitor.analyzer = analyzer
        self.visitor.cse_map = cse_map
        self.visitor._str_cache = dict()
        for expr_hash, var_name in sorted_cse:
            expr_obj: Expr = analyzer.expr_objects[expr_hash]
            # Definitions are written fully expanded (no CSE substitution).
            self.visitor.in_cse_def = True
            expr_code = self.visitor.visit(expr_obj)
            self.visitor.in_cse_def = False
            lines.append(f" {var_name} = {expr_code}")

    # Each equation writes directly into the preallocated output buffer.
    for i, eq in enumerate(equations):
        expr_code = self.visitor.visit(eq)
        lines.append(f" data_out[{offset + i}] = {expr_code}")

    # Without equations the function would consist of its header only, which is
    # not valid Python; give it an explicit empty body.
    if len(lines) == 1:
        lines.append(" pass")

    full_source = "\n".join(lines)
    py_func = _compile_to_file(full_source, canonical_func_name)
    signature_tpe = self.generate_signature(
        kernel_tpe=EagerKernelKind.Residual,
        n_variables=len(self.variables_objs),
        n_parameters=len(self.parameters_objs),
        nnz=offset + len(equations),
        with_history2=True,
    )
    _reset_symbolic_codegen_state(self.visitor)
    GENERATED_KERNEL_CACHE.set_entry("residual", cache_key, GeneratedKernelCacheEntry(py_func, signature_tpe))
    return py_func, signature_tpe
[docs]
def compile_ad_kernel(
        self,
        equations: List[Expr],
        func_name: str = "ad_step",
        use_cse: bool = True,
        active_indices: set | None = None,
) -> Tuple[Callable, object]:
    """
    Compile a sparse forward-mode AD kernel into an in-place eager function.

    The generated function is named after the cache key, not after
    ``func_name`` (the argument is accepted but not used here). Entry ``i`` of
    ``data_out`` receives the directional derivative of equation ``i``, or
    ``0.0`` when that derivative is known to vanish. An empty equation list
    yields a kernel with an empty body.

    :param equations: Residual equations to differentiate.
    :type equations: List[Expr]
    :param func_name: Unused in this method.
    :type func_name: str
    :param use_cse: Whether to emit common subexpressions.
    :type use_cse: bool
    :param active_indices: Colored column subset activated in the current JVP sweep.
    :type active_indices: set | None
    :return: Pair ``(python_function, eager_signature)``.
    :rtype: Tuple[Callable, object]
    """
    adv: ADVisitor = ADVisitor(
        self.var_map,
        self.param_map,
        self.strategy,
        seeds_var='seeds',
        active_indices=active_indices,
    )
    lines: List[str] = list()
    py_func: Callable
    signature_tpe: object
    full_source: str
    cache_key: str
    cache_entry: GeneratedKernelCacheEntry | None
    canonical_func_name: str

    cache_key = _build_ad_codegen_cache_key(
        equations,
        self.var_map,
        self.param_map,
        type(self.strategy).__name__,
        use_cse,
        active_indices,
        len(self.variables_objs),
        len(self.parameters_objs),
    )
    cache_entry = GENERATED_KERNEL_CACHE.get_entry("ad", cache_key)
    if cache_entry is not None:
        return cache_entry.get_python_function(), cache_entry.get_signature_tpe()

    _reset_symbolic_codegen_state(adv)
    canonical_func_name = f"eager_ad_{cache_key[:16]}"

    # The eager AD ABI also writes into a caller-owned buffer to avoid JVP allocations.
    lines.append(f"def {canonical_func_name}(states, seeds, params, history, d_history, h, data_out, history2):")

    # Emit CSE definitions first so both values and directional derivatives can be reused.
    if use_cse:
        analyzer = SubexpressionAnalyzer()
        cse_map = analyzer.analyze(equations)
        sorted_cse = sorted(cse_map.items(), key=_get_cse_sort_index)
        adv.analyzer = analyzer
        adv.cse_map = cse_map
        adv._str_cache = dict()
        for expr_hash, var_name in sorted_cse:
            expr_obj = analyzer.expr_objects[expr_hash]
            expr_val: str
            expr_dot: str
            adv.in_cse_def = True
            expr_val, expr_dot = adv.visit(expr_obj)
            adv.in_cse_def = False
            lines.append(f" {var_name} = {expr_val}")
            # Only temporaries with a non-zero derivative get a "<name>_d" companion.
            if expr_dot != "0.0":
                lines.append(f" {var_name}_d = {expr_dot}")
                adv.cse_has_dot.add(var_name)

    # Every output slot is assigned, so the caller's buffer needs no pre-clearing.
    for i, eq in enumerate(equations):
        _value_code: str
        dot_code: str
        _value_code, dot_code = adv.visit(eq)
        if dot_code != "0.0":
            lines.append(f" data_out[{i}] = {dot_code}")
        else:
            lines.append(f" data_out[{i}] = 0.0")

    # Without equations the function would consist of its header only, which is
    # not valid Python; give it an explicit empty body.
    if len(lines) == 1:
        lines.append(" pass")

    full_source = "\n".join(lines)
    py_func = _compile_to_file(full_source, canonical_func_name)
    signature_tpe = self.generate_signature(
        kernel_tpe=EagerKernelKind.AutomaticDifferentiation,
        n_variables=len(self.variables_objs),
        n_parameters=len(self.parameters_objs),
        nnz=len(equations),
        with_history2=True,
    )
    _reset_symbolic_codegen_state(adv)
    GENERATED_KERNEL_CACHE.set_entry("ad", cache_key, GeneratedKernelCacheEntry(py_func, signature_tpe))
    return py_func, signature_tpe
[docs]
def compile_matrix_kernel(
        self,
        template_eq: Expr,
        func_name: str,
        template_vars: List[Var],
) -> Tuple[Callable, object]:
    """
    Build, or fetch from the kernel cache, the in-place eager matrix kernel of a structural group.

    The generated function stores the vectorized value of the template equation in the
    leading ``indices.shape[0]`` entries of the caller-owned ``data_out`` buffer.

    :param template_eq: Canonical template equation for a structural group.
    :type template_eq: Expr
    :param func_name: Requested function name. The emitted function is named after the
        cache key, so this argument does not influence the generated source.
    :type func_name: str
    :param template_vars: Ordered runtime variables of the structural group.
    :type template_vars: List[Var]
    :return: Pair ``(python_function, eager_signature)``.
    :rtype: Tuple[Callable, object]
    """
    # Position of every template variable inside the group; a variable that carries a
    # ``base_var`` is registered under the name of that base variable.
    col_map: Dict[str, int] = {
        (var.name if var.base_var is None else var.base_var.name): position
        for position, var in enumerate(template_vars)
    }

    cache_key = _build_matrix_codegen_cache_key(
        template_eq,
        self.var_map,
        self.param_map,
        type(self.strategy).__name__,
        col_map,
        len(self.variables_objs),
        len(self.parameters_objs),
    )

    # Cache hit: hand back the previously generated kernel and its signature.
    cached = GENERATED_KERNEL_CACHE.get_entry("matrix", cache_key)
    if cached is not None:
        return cached.get_python_function(), cached.get_signature_tpe()

    # Cache miss: translate the template equation with the matrix visitor.
    matrix_visitor: MatrixVectorizedVisitor = MatrixVectorizedVisitor(
        self.var_map, self.param_map, self.strategy, col_map
    )
    body_expr: str = matrix_visitor.visit(template_eq)

    # The function name is derived from the cache key so identical kernels share one name.
    kernel_name = f"eager_matrix_{cache_key[:16]}"
    source: str = "\n".join((
        f"def {kernel_name}(states, params, history, d_history, h, indices, data_out, history2):",
        f" data_out[:indices.shape[0]] = {body_expr}",
    ))
    kernel: Callable = _compile_to_file(source, kernel_name)

    signature: object = self.generate_signature(
        kernel_tpe=EagerKernelKind.MatrixVectorized,
        n_variables=len(self.variables_objs),
        n_parameters=len(self.parameters_objs),
        nnz=len(template_vars),
        with_history2=True,
    )

    GENERATED_KERNEL_CACHE.set_entry("matrix", cache_key, GeneratedKernelCacheEntry(kernel, signature))
    return kernel, signature
# ==============================================================================
# 7. RMS Native Compiler (Continuous Time & Sparse Jacobians)
# ==============================================================================
[docs]
class RMSCompiler(EquationCompiler):
    """
    Compiler for RMS (Root Mean Square) continuous-time simulations.

    Symbolic expressions are translated to Python source with the
    `expression2numba` translator and compiled through `_compile_to_file`.
    Four kinds of callables can be generated: the residual (RHS) vector, a
    sparse Jacobian evaluator with a precomputed CSC pattern, the
    event-parameter evaluator and the backward-difference derivative evaluator.
    """
    __slots__ = ['dt_var', 'compiler_names_dict', 'diff_vars', 'v_params']

    def __init__(self,
                 variables: List[Var],
                 diff_vars: List[Var],
                 v_params: List[Var],
                 c_params: List[Var],
                 dt_var: Var,
                 compiler_names_dict: Dict[int, str]) -> None:
        """
        Initializes the RMS Compiler with the system's mathematical components.

        Args:
            variables (List[Var]): State and algebraic variables of the system.
            diff_vars (List[Var]): Differential variables (dx/dt).
            v_params (List[Var]): Variable (event-driven) parameters.
            c_params (List[Var]): Constant parameters; forwarded to the base
                class as its ``parameters``.
            dt_var (Var): The symbolic variable representing the integration time step.
            compiler_names_dict (Dict[int, str]): Dictionary mapping variable UIDs to
                their executable string representations. It is stored by reference
                and updated in place with the entry ``dt_var.uid -> 'h'``.
        """
        super().__init__(variables=variables, parameters=c_params, method=DynamicIntegrationMethod.DaeBackEuler)
        self.dt_var = dt_var
        self.compiler_names_dict = compiler_names_dict
        self.diff_vars = diff_vars
        self.v_params = v_params
        # The time step is emitted under the name `h` in generated code.
        self.compiler_names_dict[dt_var.uid] = 'h'

    def compile_rhs(self, equations: List[Expr], func_name: str) -> Callable[[Vec, Vec, Vec, Vec], Vec]:
        """
        Compiles the residual equations (RHS) into an executable function.

        The generated function has the signature ``(vrs, diff, vprms, cprms)`` and
        returns a freshly allocated float64 vector with one entry per equation.

        Args:
            equations (List[Expr]): The list of symbolic expressions to evaluate.
            func_name (str): The desired name for the compiled target function.

        Returns:
            Callable: An executable function computing the RHS residuals.
        """
        # NOTE(review): `h` is not part of this signature although the time-step
        # variable is mapped to 'h'; presumably RHS equations never reference it - confirm.
        lines: List[str] = [
            f"def {func_name}(vrs, diff, vprms, cprms):",
            # Keep the legacy argument name (`vrs`) while also exposing the `vars`
            # alias, so equations compiled with either naming convention work.
            " vars = vrs",
            f" out = np.zeros({len(equations)}, dtype=np.float64)",
        ]
        for i, eq in enumerate(equations):
            expr_str = expression2numba(eq, self.compiler_names_dict)
            lines.append(f" out[{i}] = {expr_str}")
        lines.append(" return out")
        return _compile_to_file("\n".join(lines), func_name)

    def compile_sparse_jacobian(self, eqs: List[Expr],
                                wrt_vars: List[Var],
                                func_name: str) -> Callable[[Vec, Vec, Vec, Vec, float], sp.csc_matrix]:
        """
        Compiles a sparse Jacobian evaluator with a precomputed CSC sparsity pattern.

        Only the (equation, variable) pairs reported by `_collect_candidate_wrt_uids`
        are differentiated; derivatives that are the constant zero are discarded.

        Args:
            eqs (List[Expr]): The list of symbolic equations (F or G).
            wrt_vars (List[Var]): The variables to differentiate with respect to (x or y).
            func_name (str): Base name; the generated filler is called ``{func_name}_filler``.

        Returns:
            Callable: A `SparseJacobianEvaluatorWrapper` around the generated filler and the
            CSC template, or an `EmptySparseJacobianEvaluator` when no entry is structurally
            non-zero.
        """
        n_rows = len(eqs)
        n_cols = len(wrt_vars)
        wrt_map = {v.uid: (i, v) for i, v in enumerate(wrt_vars)}

        # Collect (column, row, derivative) triplets for every structurally non-zero entry.
        triplets: List[Tuple[int, int, Expr]] = list()
        for row, eq in enumerate(eqs):
            for uid in _collect_candidate_wrt_uids(eq, wrt_map):
                col, var = wrt_map[uid]
                d_expr = eq.diff(var, dt=self.dt_var)
                if isinstance(d_expr, Const) and d_expr.value == 0:
                    continue
                triplets.append((col, row, d_expr))

        # The CSC arrays built below are only valid if this key orders the triplets
        # by column (the key function is defined elsewhere in the file).
        triplets.sort(key=_get_triplet_sort_key)

        # Handle edge case: Empty Jacobian matrix (no dependencies found)
        if not triplets:
            J_empty = sp.csc_matrix((n_rows, n_cols))
            return EmptySparseJacobianEvaluator(J_empty)

        # Extract sorted coordinates and corresponding symbolic derivatives
        cols_sorted = [t[0] for t in triplets]
        rows_sorted = [t[1] for t in triplets]
        d_exprs = [t[2] for t in triplets]

        # Precompute CSC (Compressed Sparse Column) pointer arrays:
        # count the entries of each column, then accumulate into offsets.
        nnz = len(cols_sorted)
        indices = np.fromiter(rows_sorted, dtype=np.int32, count=nnz)
        indptr = np.zeros(n_cols + 1, dtype=np.int32)
        for c in cols_sorted:
            indptr[c + 1] += 1
        np.cumsum(indptr, out=indptr)

        # Initialize the sparse matrix template (values are written later by the filler)
        J_template = sp.csc_matrix((np.zeros(nnz), indices, indptr), shape=(n_rows, n_cols))

        # The filler receives `h` because the time-step variable is registered under
        # the name 'h' in `compiler_names_dict`, so derivative expressions may refer to it.
        lines = [
            f"def {func_name}_filler(vrs, diff, vprms, cprms, h, data_out):",
            " vars = vrs",
        ]
        for i, expr in enumerate(d_exprs):
            expr_str = expression2numba(expr, self.compiler_names_dict)
            lines.append(f" data_out[{i}] = {expr_str}")

        filler_fn = _compile_to_file("\n".join(lines), f"{func_name}_filler")
        return SparseJacobianEvaluatorWrapper(filler_fn=filler_fn, matrix=J_template)

    def compile_event_params_fn(self, eqs: List[Expr], alias_names_dict: Dict[int, str],
                                EVENT_PARAMS_NAME: str, TIME_NAME: str,
                                func_name: str = "event_params_fn") -> Callable[[Vec, float, Vec], Vec]:
        """
        Compiles event parameters equations into an executable function.
        This is the equivalent of SymbolicParamsVector.

        Args:
            eqs (List[Expr]): The list of symbolic expressions for event parameters.
            alias_names_dict (Dict[int, str]): Dictionary mapping var UIDs to alias names.
                It must contain every UID that is counted more than once over ``eqs``.
            EVENT_PARAMS_NAME (str): Name of the event parameters array (e.g., "vprms").
            TIME_NAME (str): Name of the time variable (e.g., "glob_time").
            func_name (str): The desired name for the compiled target function.

        Returns:
            Callable: An `EventParameterFunctionWrapper` around the generated function,
            whose signature is (event_params, time, out) -> out.
        """
        # Count how often each variable is reported over all equations.
        used_vars_count = defaultdict(int)
        for eq in eqs:
            for var in get_expression_vars(eq):
                used_vars_count[var.uid] += 1

        lines: List[str] = [f"def {func_name}({EVENT_PARAMS_NAME}, {TIME_NAME}, out):"]

        # Variables counted more than once are bound to a local alias once and
        # referenced through that alias in the expressions below.
        final_names_dict: Dict[int, str] = self.compiler_names_dict.copy()
        for uid, count in used_vars_count.items():
            if count > 1:
                final_names_dict[uid] = alias_names_dict[uid]
                lines.append(f" {alias_names_dict[uid]} = {self.compiler_names_dict[uid]}")

        for i, eq in enumerate(eqs):
            # Event-parameter expressions may reuse the same runtime variable many
            # times, especially for piecewise ramps built from time-dependent
            # expressions. The compiled expression must therefore use the alias-
            # aware name map prepared above, otherwise the generated code can
            # evaluate a simplified step-like path instead of the intended ramp.
            expr_str = expression2numba(eq, final_names_dict)
            lines.append(f" out[{i}] = {expr_str}")

        # Always close the body with a statement: with no equations the function would
        # otherwise be a bare `def` header, which is not valid Python. Returning `out`
        # also matches the documented (event_params, time, out) -> out contract.
        lines.append(" return out")

        raw_fn = _compile_to_file("\n".join(lines), func_name)
        return EventParameterFunctionWrapper(raw_fn=raw_fn, equation_count=len(eqs))

    def compile_derivative_fn(self, uid2idx_vars: Dict[int, int],
                              func_name: str = "derivative_fn") -> Callable[[Vec, Vec, Vec, float, Vec], Vec]:
        """
        Compiles the derivative evaluation function for differential variables.
        This is the equivalent of SymbolicDerivative.

        For a first-order differential variable the emitted expression is
        ``(vrs[j] - lagvars[j]) / h`` with ``j`` the index of its base variable.
        For any other order the expression already emitted for its base variable is
        reused, so that base variable must appear earlier in ``self.diff_vars``
        (a ``KeyError`` is raised otherwise).

        Args:
            uid2idx_vars (Dict[int, int]): Dictionary mapping var UIDs to their indices.
            func_name (str): The desired name for the compiled target function.

        Returns:
            Callable: A `DerivativeFunctionWrapper` around the generated function,
            whose signature is (vrs, lagvars, lagdx, h, out) -> out.
        """
        lines: List[str] = [f"def {func_name}(vrs, lagvars, lagdx, h, out):"]

        # Source text of the derivative expression emitted so far, keyed by diff-var UID.
        uid2dx_expression: Dict[int, str] = {}
        for i, diff_var in enumerate(self.diff_vars):
            base_var_uid = diff_var.base_var.uid
            if diff_var.diff_order == 1:
                j = uid2idx_vars[base_var_uid]
                rhs = f" (vrs[{j}] - lagvars[{j}]) / h"
            else:
                # NOTE(review): the lagged term is read at this variable's own position
                # `i` in `lagdx`, not at the position of its base derivative - confirm
                # this matches the layout of `lagdx` supplied by the caller.
                dx_expression = uid2dx_expression[base_var_uid]
                rhs = f" ({dx_expression} - lagdx[{i}]) / h"
            lines.append(f" out[{i}] = {rhs}")
            uid2dx_expression[diff_var.uid] = rhs

        lines.append(" return out")
        raw_fn = _compile_to_file("\n".join(lines), func_name)
        return DerivativeFunctionWrapper(raw_fn=raw_fn, diff_var_count=len(self.diff_vars))
[docs]
class RMSCompilerVec(EquationCompiler):
    """
    Vectorized compiler for RMS (Root Mean Square) continuous-time simulations.

    Same code-generation scheme as the scalar RMS compiler (symbolic expressions
    translated with `expression2numba` and compiled through `_compile_to_file`),
    but the generated RHS and Jacobian code writes whole rows of 2D arrays, the
    second axis being sized from ``vrs.shape[1]`` in the RHS kernel.
    """
    __slots__ = ['dt_var', 'compiler_names_dict', 'diff_vars', 'v_params']

    def __init__(self,
                 variables: List[Var],
                 diff_vars: List[Var],
                 v_params: List[Var],
                 c_params: List[Var],
                 dt_var: Var,
                 compiler_names_dict: Dict[int, str]) -> None:
        """
        Initializes the vectorized RMS Compiler with the system's mathematical components.

        Args:
            variables (List[Var]): State and algebraic variables of the system.
            diff_vars (List[Var]): Differential variables (dx/dt).
            v_params (List[Var]): Variable (event-driven) parameters.
            c_params (List[Var]): Constant parameters; forwarded to the base
                class as its ``parameters``.
            dt_var (Var): The symbolic variable representing the integration time step.
            compiler_names_dict (Dict[int, str]): Dictionary mapping variable UIDs to
                their executable string representations. It is stored by reference
                and updated in place with the entry ``dt_var.uid -> 'h'``.
        """
        super().__init__(variables=variables, parameters=c_params, method=DynamicIntegrationMethod.DaeBackEuler)
        self.dt_var = dt_var
        self.compiler_names_dict = compiler_names_dict
        self.diff_vars = diff_vars
        self.v_params = v_params
        # The time step is emitted under the name `h` in generated code.
        self.compiler_names_dict[dt_var.uid] = 'h'

    def compile_rhs(self, equations: List[Expr], func_name: str) -> Callable[[Vec, Vec, Vec, Vec], Vec]:
        """
        Compiles the residual equations (RHS) into an executable, vectorized function.

        The generated function has the signature ``(vrs, diff, vprms, cprms)`` and
        returns a freshly allocated float64 array of shape
        ``(len(equations), vrs.shape[1])``; each equation fills one row.

        Args:
            equations (List[Expr]): The list of symbolic expressions to evaluate.
            func_name (str): The desired name for the compiled target function.

        Returns:
            Callable: An executable function computing the RHS residuals.
        """
        # NOTE(review): `h` is not part of this signature although the time-step
        # variable is mapped to 'h'; presumably RHS equations never reference it - confirm.
        lines: List[str] = [
            f"def {func_name}(vrs, diff, vprms, cprms):",
            # Keep the legacy argument name (`vrs`) while also exposing the `vars`
            # alias, so equations compiled with either naming convention work.
            " vars = vrs",
            " n_models = vars.shape[1]",
            f" out = np.zeros(({len(equations)}, n_models), dtype=np.float64)",
        ]
        for i, eq in enumerate(equations):
            expr_str = expression2numba(eq, self.compiler_names_dict)
            lines.append(f" out[{i}, :] = {expr_str}")
        lines.append(" return out")
        return _compile_to_file("\n".join(lines), func_name)

    def compile_sparse_jacobian(self, eqs: List[Expr],
                                wrt_vars: List[Var],
                                func_name: str) -> Callable[[Vec, Vec, Vec, Vec, float], np.ndarray]:
        """
        Compiles a vectorized sparse Jacobian evaluator.

        Only the (equation, variable) pairs reported by `_collect_candidate_wrt_uids`
        are differentiated; derivatives that are the constant zero are discarded.
        The generated filler writes one row of a 2D ``data_out`` array per
        structurally non-zero entry.

        Args:
            eqs (List[Expr]): The list of symbolic equations (F or G).
            wrt_vars (List[Var]): The variables to differentiate with respect to (x or y).
            func_name (str): Base name; the generated filler is called ``{func_name}_filler``.

        Returns:
            Callable: A `SparseJacobianEvaluatorVecWrapper` holding the filler and the
            shared sparsity pattern, or an `EmptyVecSparseJacobianEvaluator` when no
            entry is structurally non-zero.
        """
        n_rows = len(eqs)
        n_cols = len(wrt_vars)
        wrt_map = {v.uid: (i, v) for i, v in enumerate(wrt_vars)}

        # Collect (column, row, derivative) triplets for every structurally non-zero entry.
        triplets: List[Tuple[int, int, Expr]] = list()
        for row, eq in enumerate(eqs):
            for uid in _collect_candidate_wrt_uids(eq, wrt_map):
                col, var = wrt_map[uid]
                d_expr = eq.diff(var, dt=self.dt_var)
                if isinstance(d_expr, Const) and d_expr.value == 0:
                    continue
                triplets.append((col, row, d_expr))

        # The CSC arrays built below are only valid if this key orders the triplets
        # by column (the key function is defined elsewhere in the file).
        triplets.sort(key=_get_triplet_sort_key)

        # Handle edge case: Empty Jacobian matrix (no dependencies found)
        if not triplets:
            return EmptyVecSparseJacobianEvaluator(n_rows, n_cols)

        d_exprs = [t[2] for t in triplets]
        nnz = len(d_exprs)

        # The generated code fills a 2D data_out array, one row per non-zero entry,
        # relying on numpy broadcasting over the second axis.
        lines = [
            f"def {func_name}_filler(vrs, diff, vprms, cprms, h, data_out):",
            " vars = vrs",
        ]
        for i, expr in enumerate(d_exprs):
            expr_str = expression2numba(expr, self.compiler_names_dict)
            lines.append(f" data_out[{i}, :] = {expr_str}")
        filler_fn = _compile_to_file("\n".join(lines), f"{func_name}_filler")

        # Precompute the CSC pattern: count the entries of each column, then
        # accumulate the counts into column offsets.
        rows_sorted = [t[1] for t in triplets]
        cols_sorted = [t[0] for t in triplets]
        indices = np.fromiter(rows_sorted, dtype=np.int32, count=nnz)
        indptr = np.zeros(n_cols + 1, dtype=np.int32)
        for c in cols_sorted:
            indptr[c + 1] += 1
        np.cumsum(indptr, out=indptr)

        return SparseJacobianEvaluatorVecWrapper(
            filler_fn=filler_fn,
            nnz=nnz,
            n_rows=n_rows,
            n_cols=n_cols,
            rows=rows_sorted,
            cols=cols_sorted,
            indices=indices,
            indptr=indptr,
        )

    def compile_event_params_fn(self, eqs: List[Expr], alias_names_dict: Dict[int, str],
                                EVENT_PARAMS_NAME: str, TIME_NAME: str,
                                func_name: str = "event_params_fn") -> Callable[[Vec, float, Vec], Vec]:
        """
        Compiles event parameters equations into an executable function.
        This is the equivalent of SymbolicParamsVector.

        Args:
            eqs (List[Expr]): The list of symbolic expressions for event parameters.
            alias_names_dict (Dict[int, str]): Dictionary mapping var UIDs to alias names.
                It must contain every UID that is counted more than once over ``eqs``.
            EVENT_PARAMS_NAME (str): Name of the event parameters array (e.g., "vprms").
            TIME_NAME (str): Name of the time variable (e.g., "glob_time").
            func_name (str): The desired name for the compiled target function.

        Returns:
            Callable: An `EventParameterFunctionWrapper` around the generated function,
            whose signature is (event_params, time, out) -> out.
        """
        # Count how often each variable is reported over all equations.
        used_vars_count = defaultdict(int)
        for eq in eqs:
            for var in get_expression_vars(eq):
                used_vars_count[var.uid] += 1

        lines: List[str] = [f"def {func_name}({EVENT_PARAMS_NAME}, {TIME_NAME}, out):"]

        # Variables counted more than once are bound to a local alias once and
        # referenced through that alias in the expressions below.
        final_names_dict: Dict[int, str] = self.compiler_names_dict.copy()
        for uid, count in used_vars_count.items():
            if count > 1:
                final_names_dict[uid] = alias_names_dict[uid]
                lines.append(f" {alias_names_dict[uid]} = {self.compiler_names_dict[uid]}")

        for i, eq in enumerate(eqs):
            # Event-parameter expressions may reuse the same runtime variable many
            # times, especially for piecewise ramps built from time-dependent
            # expressions. The compiled expression must therefore use the alias-
            # aware name map prepared above, otherwise the generated code can
            # evaluate a simplified step-like path instead of the intended ramp.
            expr_str = expression2numba(eq, final_names_dict)
            lines.append(f" out[{i}] = {expr_str}")

        # Always close the body with a statement: with no equations the function would
        # otherwise be a bare `def` header, which is not valid Python. Returning `out`
        # also matches the documented (event_params, time, out) -> out contract.
        lines.append(" return out")

        raw_fn = _compile_to_file("\n".join(lines), func_name)
        return EventParameterFunctionWrapper(raw_fn=raw_fn, equation_count=len(eqs))

    def compile_derivative_fn(self, uid2idx_vars: Dict[int, int],
                              func_name: str = "derivative_fn") -> Callable[[Vec, Vec, Vec, float, Vec], Vec]:
        """
        Compiles the derivative evaluation function for differential variables.
        This is the equivalent of SymbolicDerivative.

        For a first-order differential variable the emitted expression is
        ``(vrs[j] - lagvars[j]) / h`` with ``j`` the index of its base variable.
        For any other order the expression already emitted for its base variable is
        reused, so that base variable must appear earlier in ``self.diff_vars``
        (a ``KeyError`` is raised otherwise).

        Args:
            uid2idx_vars (Dict[int, int]): Dictionary mapping var UIDs to their indices.
            func_name (str): The desired name for the compiled target function.

        Returns:
            Callable: A `DerivativeFunctionWrapper` around the generated function,
            whose signature is (vrs, lagvars, lagdx, h, out) -> out.
        """
        lines: List[str] = [f"def {func_name}(vrs, lagvars, lagdx, h, out):"]

        # Source text of the derivative expression emitted so far, keyed by diff-var UID.
        uid2dx_expression: Dict[int, str] = {}
        for i, diff_var in enumerate(self.diff_vars):
            base_var_uid = diff_var.base_var.uid
            if diff_var.diff_order == 1:
                j = uid2idx_vars[base_var_uid]
                rhs = f" (vrs[{j}] - lagvars[{j}]) / h"
            else:
                # NOTE(review): the lagged term is read at this variable's own position
                # `i` in `lagdx`, not at the position of its base derivative - confirm
                # this matches the layout of `lagdx` supplied by the caller.
                dx_expression = uid2dx_expression[base_var_uid]
                rhs = f" ({dx_expression} - lagdx[{i}]) / h"
            lines.append(f" out[{i}] = {rhs}")
            uid2dx_expression[diff_var.uid] = rhs

        lines.append(" return out")
        raw_fn = _compile_to_file("\n".join(lines), func_name)
        return DerivativeFunctionWrapper(raw_fn=raw_fn, diff_var_count=len(self.diff_vars))