# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# SPDX-License-Identifier: MPL-2.0
import numpy as np
import numba as nb
import scipy.sparse as sp
import time
from VeraGridEngine.Utils.Symbolic.symbolic import Var, Expr, BinOp, UnOp, Func
from VeraGridEngine.Simulations.EMT.problems.emt_problem_template import (
EmtBoundaryUpdateProtocol,
EmtProblemTemplate,
get_solver_forced_event_time,
is_problem_owned_boundary_updater,
resolve_solver_boundary_updater,
)
from VeraGridEngine.enumerations import DynamicIntegrationMethod
from VeraGridEngine.Utils.Symbolic.jit_compiler import EagerEquationCompiler, EquationCompiler
from VeraGridEngine.Utils.Symbolic.diagnostic import (with_newton_diagnostics, NewtonDiagnosticsConfig,
NewtonSolveContext, dense_lstsq_fallback,
sparse_lsqr_fallback, maybe_check_index1,
maybe_apply_backtracking)
from VeraGridEngine.basic_structures import Vec, Mat
from typing import Tuple, Callable, List, Any, Dict, Set, Union
from scipy.sparse import csc_matrix
def _safe_njit(py_func: Callable[..., Any], fastmath: bool = True, cache: bool = True, signature: Any = None) -> Callable[..., Any]:
    """
    Compile a Python function with Numba's ``njit``.

    The caller must explicitly pass ``cache=False`` if the function is generated via strings.

    :param py_func: The python function to compile.
    :type py_func: Callable[..., Any]
    :param fastmath: Forwarded to ``nb.njit`` as ``fastmath``, defaults to True.
    :type fastmath: bool
    :param cache: Forwarded to ``nb.njit`` as ``cache``, defaults to True.
    :type cache: bool
    :param signature: Optional Numba signature; when given it is passed as the first positional argument.
    :type signature: Any
    :return: The Numba JIT compiled function.
    :rtype: Callable[..., Any]
    """
    # Build the decorator first, then apply it once; only the presence of an
    # explicit signature changes how the decorator is created.
    if signature is None:
        decorator = nb.njit(fastmath=fastmath, cache=cache)
    else:
        decorator = nb.njit(signature, fastmath=fastmath, cache=cache)
    return decorator(py_func)
# ==============================================================================
# Sparse Forward-Mode AD Jacobian with Graph Coloring (JVP-based)
# ==============================================================================
@nb.njit(cache=True, fastmath=True)
def _scatter_color_jvp_to_csc_data(jvp: np.ndarray,
                                   data: np.ndarray,
                                   color_ptr: np.ndarray,
                                   col_ptr: np.ndarray,
                                   row_idx: np.ndarray,
                                   data_idx: np.ndarray,
                                   color_id: int) -> None:
    """
    Copy the entries of one JVP result into the CSC data array for every column of one color.

    ``color_ptr[color_id]:color_ptr[color_id + 1]`` selects the column slots of the color;
    for each slot, ``col_ptr`` delimits the scatter entries, and each entry copies
    ``jvp[row_idx[p]]`` into ``data[data_idx[p]]``.
    """
    first_slot: int = int(color_ptr[color_id])
    end_slot: int = int(color_ptr[color_id + 1])
    for slot in range(first_slot, end_slot):
        entry_begin: int = int(col_ptr[slot])
        entry_end: int = int(col_ptr[slot + 1])
        for entry in range(entry_begin, entry_end):
            data[data_idx[entry]] = jvp[row_idx[entry]]
def _compile_master_jacobian_kernel(ad_kernel: Callable[..., Any], n_colors: int) -> Callable[..., Any]:
    """
    Create the callable that runs one AD kernel over every graph color.

    :param ad_kernel: Generic AD kernel reused across graph colors.
    :type ad_kernel: Callable[..., Any]
    :param n_colors: Number of graph colors.
    :type n_colors: int
    :return: A ``SparseAdMasterJacobianDispatcher`` wrapping the kernel.
    :rtype: Callable[..., Any]
    """
    dispatcher = SparseAdMasterJacobianDispatcher(ad_kernel, n_colors)
    return dispatcher
class SparseAdMasterJacobianDispatcher:
    """
    Callable object that runs a single generic AD kernel once per graph color
    and scatters each result into the CSC data buffer.
    """
    __slots__ = ("_ad_kernel", "_n_colors")

    def __init__(self, ad_kernel: Callable[..., Any], n_colors: int) -> None:
        """
        Keep the AD kernel and the color count for later calls.

        :param ad_kernel: Generic AD kernel reused across graph colors.
        :param n_colors: Number of graph colors.
        :return: None.
        """
        self._n_colors: int = n_colors
        self._ad_kernel: Callable[..., Any] = ad_kernel

    def __call__(self,
                 states: Vec,
                 seed_matrix: Mat,
                 params: Vec,
                 history: Vec,
                 d_history: Vec,
                 h: float,
                 history2: Vec,
                 data: Vec,
                 color_ptr: np.ndarray,
                 col_ptr: np.ndarray,
                 row_idx: np.ndarray,
                 data_idx: np.ndarray,
                 work_jvp: Mat) -> None:
        """
        Run the AD kernel for every color and scatter each JVP into ``data``.

        :param states: Current Newton iterate.
        :param seed_matrix: Coloring seed matrix; row ``c`` is the seed of color ``c``.
        :param params: Full parameter vector.
        :param history: Previous accepted state.
        :param d_history: Previous derivative vector.
        :param h: Effective time step.
        :param history2: Secondary history vector.
        :param data: CSC numeric data buffer (written in place).
        :param color_ptr: Color-to-column pointer array.
        :param col_ptr: Column-to-scatter pointer array.
        :param row_idx: Row indices used by the scatter map.
        :param data_idx: CSC data positions used by the scatter map.
        :param work_jvp: Reusable JVP work buffer; row ``c`` receives the JVP of color ``c``.
        :return: None.
        """
        run_kernel = self._ad_kernel
        for color in range(self._n_colors):
            # Row views of the work buffer are written in place by the kernel.
            jvp_row: Vec = work_jvp[color]
            run_kernel(states, seed_matrix[color], params, history, d_history, h, jvp_row, history2)
            _scatter_color_jvp_to_csc_data(jvp_row, data, color_ptr, col_ptr, row_idx, data_idx, color)
def greedy_color_columns(col_rows: List[List[int]], n_rows: int) -> Tuple[np.ndarray, int]:
    """
    Computes a greedy coloring of the column dependency graph to minimize AD sweeps.

    Two columns are adjacent when they have a non-zero in the same row. Columns are
    visited in decreasing order of degree and each one receives the smallest color
    not already used by one of its neighbours.

    :param col_rows: List of row indices where each column has non-zeros.
    :type col_rows: List[List[int]]
    :param n_rows: Total number of rows in the matrix.
    :type n_rows: int
    :return: An int32 array of color IDs per column and the total number of colors.
    :rtype: Tuple[np.ndarray, int]
    """
    n_cols: int = len(col_rows)

    # Invert the column -> rows map so that conflicts can be found row by row.
    row_cols: List[List[int]] = [[] for _ in range(n_rows)]
    for j, rows in enumerate(col_rows):
        for r in rows:
            row_cols[r].append(j)

    # Every pair of columns that meets in a row is connected. Each column is
    # first merged with the full row membership, then removed from its own set.
    adj: List[Set[int]] = [set() for _ in range(n_cols)]
    for cols in row_cols:
        if len(cols) > 1:
            for j in cols:
                adj[j].update(cols)
    for j, neighbours in enumerate(adj):
        neighbours.discard(j)

    degrees: np.ndarray = np.array([len(neighbours) for neighbours in adj], dtype=np.int32)
    order: np.ndarray = np.argsort(-degrees)

    colors: np.ndarray = np.full(n_cols, -1, dtype=np.int32)
    n_colors: int = 0
    for j in order:
        # Only the neighbours' colors are inspected, so the work per column is
        # proportional to its degree rather than to the total number of columns.
        # Uncolored neighbours contribute -1, which never blocks a valid color.
        taken: Set[int] = {int(colors[nbj]) for nbj in adj[j]}
        color: int = 0
        while color in taken:
            color += 1
        colors[j] = color
        if color >= n_colors:
            n_colors = color + 1

    return colors, n_colors
[docs]
class BoundaryUpdaterInterface:
[docs]
def update(self, t: float, x_prev: Vec, full_params: Vec) -> None:
raise NotImplementedError
[docs]
def get_next_forced_event_time(self, t_prev: float, t_target: float) -> float | None:
raise NotImplementedError
class Predictor:
    """
    Explicit predictor used to seed the Newton initial guess ``x_iter``.
    """
    __slots__ = ['n_states']

    def __init__(self,
                 n_states: int) -> None:
        """
        :param n_states: Number of leading entries treated as states by the Euler predictor.
        """
        self.n_states = n_states

    def predict(self,
                x_iter: Vec,
                x_prev: Vec,
                dx_prev: Vec,
                h: float,
                pred_method: Union[DynamicIntegrationMethod | None] = DynamicIntegrationMethod.DaeBackEuler) -> Vec:
        """
        Write the predictor guess into ``x_iter`` in place and return it.

        :param x_iter: Array to be written with the predictor guess (in-place).
        :param x_prev: Previous full variable vector at time n.
        :param dx_prev: Previous derivative vector (only first n_states are meaningful).
        :param h: Time step.
        :param pred_method: Integration method enum/value at the prediction step.
        :return: ``x_iter``.
        :raises NotImplementedError: When ``pred_method`` is ``DaeBDF2``.
        """
        # Baseline guess: start from the previous solution.
        x_iter[:] = x_prev

        if pred_method == DynamicIntegrationMethod.OdeEuler:
            return self._predict_euler_state(x_iter, x_prev, dx_prev, h)

        if pred_method == DynamicIntegrationMethod.DaeBDF2:
            raise NotImplementedError(f"predictor method {pred_method} ")

        # Any other method keeps the baseline guess.
        return x_iter

    # ---------------------------------------------------------------------
    # Predictor implementations
    # ---------------------------------------------------------------------
    def _predict_euler_state(self,
                             x_iter: Vec,
                             x_prev: Vec,
                             dx_prev: Vec,
                             h: float) -> Vec:
        """
        Explicit Euler step on the first ``n_states`` entries:
            x_{n+1}^0 = x_n + h * dx_n
        The remaining entries of ``x_iter`` are left untouched.
        """
        state_range = slice(0, self.n_states)
        x_iter[state_range] = x_prev[state_range] + h * dx_prev[state_range]
        return x_iter
def fill_full_parameter_buffer(runtime_params: Vec, static_params: Vec, full_params_out: Vec) -> None:
    """
    Copy the runtime and static parameter slices into one preallocated buffer.

    The destination is laid out as ``[runtime | static]``.

    :param runtime_params: Runtime parameter slice.
    :type runtime_params: Vec
    :param static_params: Static parameter slice.
    :type static_params: Vec
    :param full_params_out: Destination full parameter vector (written in place).
    :type full_params_out: Vec
    :return: None.
    :rtype: None
    """
    runtime_end: int = runtime_params.shape[0]
    static_end: int = runtime_end + static_params.shape[0]

    # Runtime values come first; they are rewritten on every call because the
    # caller may have changed them since the previous one.
    full_params_out[:runtime_end] = runtime_params

    # Static values follow immediately after the runtime slice.
    full_params_out[runtime_end:static_end] = static_params
def evaluate_batched_residual(
        kernel_list: List[Callable[..., Any]],
        x_iter: Vec,
        full_params: Vec,
        x_prev: Vec,
        dx_prev: Vec,
        h_eff: float,
        x_prev2: Vec,
        residual_out: Vec,
) -> float:
    """
    Run every residual batch kernel against a caller-owned buffer and return its infinity norm.

    :param kernel_list: Batched residual kernels.
    :type kernel_list: List[Callable[..., Any]]
    :param x_iter: Current Newton iterate.
    :type x_iter: Vec
    :param full_params: Full parameter vector.
    :type full_params: Vec
    :param x_prev: Previous accepted state.
    :type x_prev: Vec
    :param dx_prev: Previous derivative vector.
    :type dx_prev: Vec
    :param h_eff: Effective local time step.
    :type h_eff: float
    :param x_prev2: Secondary state history.
    :type x_prev2: Vec
    :param residual_out: Destination residual buffer; zeroed, then written by the kernels.
    :type residual_out: Vec
    :return: Residual infinity norm.
    :rtype: float
    """
    # Start from a clean buffer; every batch kernel then writes into the same array.
    residual_out.fill(0.0)
    for batch_kernel in kernel_list:
        batch_kernel(x_iter, full_params, x_prev, dx_prev, h_eff, residual_out, x_prev2)

    inf_norm = np.linalg.norm(residual_out, ord=np.inf)
    return float(inf_norm)
class AdBacktrackingResidualEvaluator:
    """
    Holds the residual-evaluation context of one Newton step so that trial
    iterates can be evaluated through a two-argument callback.
    """
    __slots__ = ["_kernel_list", "_full_params", "_x_prev", "_dx_prev", "_h_eff", "_x_prev2"]

    def __init__(self) -> None:
        """
        Create an evaluator with no context set.

        :return: None.
        :rtype: None
        """
        self._kernel_list: List[Callable[..., Any]] | None = None
        self._full_params: Vec | None = None
        self._x_prev: Vec | None = None
        self._dx_prev: Vec | None = None
        self._x_prev2: Vec | None = None
        self._h_eff: float = 0.0

    def set_context(
            self,
            kernel_list: List[Callable[..., Any]],
            full_params: Vec,
            x_prev: Vec,
            dx_prev: Vec,
            h_eff: float,
            x_prev2: Vec,
    ) -> None:
        """
        Remember the inputs needed to evaluate the residual for the current Newton step.

        The arrays are stored by reference; only ``h_eff`` is converted (to ``float``).

        :param kernel_list: Batched residual kernels.
        :type kernel_list: List[Callable[..., Any]]
        :param full_params: Full parameter vector.
        :type full_params: Vec
        :param x_prev: Previous accepted state.
        :type x_prev: Vec
        :param dx_prev: Previous derivative vector.
        :type dx_prev: Vec
        :param h_eff: Effective local time step.
        :type h_eff: float
        :param x_prev2: Secondary state history.
        :type x_prev2: Vec
        :return: None.
        :rtype: None
        """
        self._h_eff = float(h_eff)
        self._kernel_list = kernel_list
        self._full_params = full_params
        self._x_prev = x_prev
        self._dx_prev = dx_prev
        self._x_prev2 = x_prev2

    def evaluate(self, candidate_x: Vec, out_res: Vec) -> float:
        """
        Evaluate the residual at one trial iterate.

        :param candidate_x: Trial iterate.
        :type candidate_x: Vec
        :param out_res: Destination residual buffer.
        :type out_res: Vec
        :return: Residual infinity norm, or ``np.inf`` when any part of the context is unset.
        :rtype: float
        """
        required = (self._kernel_list, self._full_params, self._x_prev, self._dx_prev, self._x_prev2)
        # Identity checks only: the stored arrays must never be tested for truthiness.
        if any(item is None for item in required):
            return np.inf

        return evaluate_batched_residual(
            kernel_list=self._kernel_list,
            x_iter=candidate_x,
            full_params=self._full_params,
            x_prev=self._x_prev,
            dx_prev=self._dx_prev,
            h_eff=self._h_eff,
            x_prev2=self._x_prev2,
            residual_out=out_res,
        )
[docs]
class SparseADJacobian:
"""
Sparse Jacobian evaluator with DEBUG TIMING.
"""
__slots__ = [
'equations', 'variables', 'parameters', 'method', 'use_cse', 'dtype',
'n_rows', 'n_cols', 'var_map', 'param_map', 'col_rows', 'J',
'colors', 'n_colors', 'color_groups', 'color_ptr', 'color_cols',
'col_ptr', 'row_idx', 'data_idx', '_compiler', '_ad_kernel', '_master_dispatcher',
'_seed_matrix', '_jvp_work_buffer', '_data_buffer'
]
def __init__(self,
equations: List[Any],
variables: List[Any],
parameters: List[Any],
method: DynamicIntegrationMethod,
use_cse: bool = True,
dtype: Any=np.float64)-> None:
t_start = time.perf_counter()
print_enabled: bool = False
self.equations = equations
self.variables = variables
self.parameters = parameters
self.method = method
self.use_cse = use_cse
self.dtype = dtype
self.n_rows = len(equations)
self.n_cols = len(variables)
self.var_map: Dict[int, int] = dict()
for i, v in enumerate(variables):
self.var_map[v.uid] = i
self.param_map: Dict[int, int] = dict()
for i, p in enumerate(parameters):
self.param_map[p.uid] = i
t_maps: float = time.perf_counter()
if print_enabled:
print(f" [AD-Debug] Maps built in {t_maps - t_start:.4f}s")
# 1) Sparsity Detection
col_rows: List[List[int]] = [list() for _ in range(self.n_cols)]
count_edges: int = 0
for r, eq in enumerate(equations):
stack: List[Any] = list()
stack.append(eq)
visited: Set[int] = set()
uids: Set[int] = set()
while stack:
n: Any = stack.pop()
if id(n) in visited:
pass
else:
visited.add(id(n))
match n:
case Var(uid=u, base_var=bv):
uids.add(u)
if bv is not None:
uids.add(bv.uid)
else:
pass
case BinOp(left=l, right=r_node):
stack.append(l)
stack.append(r_node)
case UnOp(operand=op):
stack.append(op)
case Func(arg=a):
stack.append(a)
case _:
pass
for uid in uids:
j: int | None = self.var_map.get(uid, None)
if j is not None:
col_rows[j].append(r)
count_edges += 1
else:
pass
self.col_rows = [sorted(set(rows)) for rows in col_rows]
t_sparsity = time.perf_counter()
if print_enabled:
print(f" [AD-Debug] Sparsity Analysis done in {t_sparsity - t_maps:.4f}s. Edges found: {count_edges}")
# 2) CSC Structure
indptr = np.zeros(self.n_cols + 1, dtype=np.int32)
nnz = 0
for j in range(self.n_cols):
nnz += len(self.col_rows[j])
indptr[j + 1] = nnz
indices = np.empty(nnz, dtype=np.int32)
data = np.zeros(nnz, dtype=self.dtype)
k = 0
for j in range(self.n_cols):
for r in self.col_rows[j]:
indices[k] = r
k += 1
self.J = csc_matrix((data, indices, indptr), shape=(self.n_rows, self.n_cols))
t_csc = time.perf_counter()
if print_enabled:
print(f" [AD-Debug] CSC Matrix built in {t_csc - t_sparsity:.4f}s")
# 3) Coloring
self.colors, self.n_colors = greedy_color_columns(self.col_rows, self.n_rows)
groups: List[List[int]] = [list() for _ in range(self.n_colors)]
for j in range(self.n_cols):
groups[self.colors[j]].append(j)
self.color_groups = groups
t_color = time.perf_counter()
if print_enabled:
print(f" [AD-Debug] Graph Coloring done in {t_color - t_csc:.4f}s. Colors: {self.n_colors}")
# 4) Precompute scatter map
color_cols: List[int] = list()
color_ptr = np.zeros(self.n_colors + 1, dtype=np.int32)
for c in range(self.n_colors):
color_ptr[c] = len(color_cols)
color_cols.extend(self.color_groups[c])
color_ptr[self.n_colors] = len(color_cols)
col_ptr = np.zeros(len(color_cols) + 1, dtype=np.int32)
total_map = 0
for kk, j in enumerate(color_cols):
total_map += len(self.col_rows[j])
col_ptr[kk + 1] = total_map
row_idx = np.empty(total_map, dtype=np.int32)
data_idx = np.empty(total_map, dtype=np.int32)
pos = 0
indptr_arr = self.J.indptr
for kk, j in enumerate(color_cols):
base = indptr_arr[j]
rows = self.col_rows[j]
for local, r in enumerate(rows):
row_idx[pos] = r
data_idx[pos] = base + local
pos += 1
self.color_ptr = color_ptr
self.color_cols = np.asarray(color_cols, dtype=np.int32)
self.col_ptr = col_ptr
self.row_idx = row_idx
self.data_idx = data_idx
t_scatter = time.perf_counter()
if print_enabled:
print(f" [AD-Debug] Scatter Map built in {t_scatter - t_color:.4f}s")
# 5) Compile Optimization: one generic eager AD kernel reused across colors.
self._compiler = EagerEquationCompiler(variables=variables, parameters=parameters, method=method)
if print_enabled:
print(f" [AD-Opt] Compiling one generic AD kernel for {self.n_colors} colors...")
t_comp_start = time.perf_counter()
k_name = f"ad_step_generic_{self.method.name}_{self.n_rows}_{self.n_cols}"
py_ad, signature_tpe = self._compiler.compile_ad_kernel(
equations,
func_name=k_name,
use_cse=use_cse,
active_indices=None,
)
self._ad_kernel = _safe_njit(py_ad, cache=True, fastmath=True, signature=signature_tpe)
self._master_dispatcher = _compile_master_jacobian_kernel(self._ad_kernel, self.n_colors)
self._seed_matrix = np.zeros((self.n_colors, self.n_cols), dtype=self.dtype)
for c in range(self.n_colors):
self._seed_matrix[c, self.color_groups[c]] = 1.0
self._jvp_work_buffer = np.zeros((self.n_colors, self.n_rows), dtype=self.dtype)
self._data_buffer = self.J.data
t_comp_end = time.perf_counter()
if print_enabled:
print(f" [AD-Opt] Generic kernel compiled in {t_comp_end - t_comp_start:.4f}s")
print(f" [AD-Debug] Total Init Time: {t_comp_end - t_start:.4f}s")
def __call__(self,
states: np.ndarray,
params: np.ndarray,
history: np.ndarray,
d_history: np.ndarray,
h: float,
history2: Union[np.ndarray | None] = None) -> csc_matrix:
"""
Evaluates the AD Jacobian.
:param states: Current states.
:type states: np.ndarray
:param params: Current parameters.
:type params: np.ndarray
:param history: State history.
:type history: np.ndarray
:param d_history: Derivative history.
:type d_history: np.ndarray
:param h: Time step.
:type h: float
:param history2: Secondary state history.
:type history2: np.ndarray | None
:return: Evaluated sparse matrix.
:rtype: csc_matrix
"""
self._master_dispatcher(
states,
self._seed_matrix,
params,
history,
d_history,
h,
history2,
self._data_buffer,
self.color_ptr,
self.col_ptr,
self.row_idx,
self.data_idx,
self._jvp_work_buffer,
)
return self.J
class JitAdSolver:
    """
    Implicit time-domain solver that uses JIT-compiled residual batches and a
    sparse AD Jacobian inside a Newton iteration.
    """
    __slots__ = [
        'problem', 't0', 't_end', 'h', 'method', 'pred_method', 'dense_threshold', 'verbose', 'newton_max_iter',
        'steps', 't', 'y', 'dy', 'jit_kernels_ad', 'jit_jacobian_ad',
        'state_vars', 'algebraic_vars', 'state_eqs', 'algebraic_eqs', '_newton_diag_config',
        '_predictor', '_runtime_param_count', '_static_parameter_buffer', '_full_parameter_buffer',
        '_residual_buffer', '_trial_state_buffer', '_trial_residual_buffer',
        '_backend_build_stats', '_last_runtime_stats', '_last_sim_loop_time'
    ]

    def __init__(self,
                 problem: EmtProblemTemplate,
                 t0: float,
                 t_end: float,
                 h: float,
                 method: DynamicIntegrationMethod = DynamicIntegrationMethod.DaeTrapezoidal,
                 pred_method: DynamicIntegrationMethod = None,
                 dense_threshold: int = 100,
                 verbose: bool = False,
                 newton_max_iter: int = 15,
                 newton_diag_config: NewtonDiagnosticsConfig | None = None) -> None:
        """
        Initializes the JIT AD Solver.

        :param problem: The DAE problem definition.
        :type problem: EmtProblemTemplate
        :param t0: Initial time.
        :type t0: float
        :param t_end: End time.
        :type t_end: float
        :param h: Time step.
        :type h: float
        :param method: DynamicIntegrationMethod (DaeTrapezoidal, DaeBackEuler, DaeBDF2).
        :type method: DynamicIntegrationMethod
        :param pred_method: DynamicIntegrationMethod handed to the predictor.
        :type pred_method: DynamicIntegrationMethod
        :param dense_threshold: Systems with at most this many unknowns use the dense linear solver.
        :type dense_threshold: int
        :param verbose: Print compilation and simulation timings.
        :type verbose: bool
        :param newton_max_iter: Maximum Newton iterations per local EMT substep.
        :type newton_max_iter: int
        :param newton_diag_config: Newton diagnostics configuration; when falsy, a configuration
            with every diagnostic feature switched off is used.
        :type newton_diag_config: NewtonDiagnosticsConfig | None
        """
        self.problem = problem
        self.t0 = t0
        self.t_end = t_end
        self.h = h
        self.method = method
        self.pred_method = pred_method
        self.dense_threshold = dense_threshold
        self.verbose = verbose
        self.newton_max_iter: int = int(newton_max_iter)

        if not newton_diag_config:
            newton_diag_config = NewtonDiagnosticsConfig(
                compute_dense_cond=False,
                enable_fallback=False,
                enable_index1_check=False,
                enable_backtracking=False,
            )
        self._newton_diag_config = newton_diag_config

        # Result storage, one row per macro step plus the initial point.
        self.steps = int(np.ceil((t_end - t0) / h))
        n_vars: int = self.problem.get_all_vars_number()
        self.t: Vec = np.empty(self.steps + 1)
        self.y: Mat = np.empty((self.steps + 1, n_vars))
        self.dy: Mat = np.empty((self.steps + 1, self.problem.get_diff_var_number()))

        # Compiled backends, keyed by integration method.
        self.jit_kernels_ad: Dict[DynamicIntegrationMethod, List[Callable]] = {}
        self.jit_jacobian_ad: Dict[DynamicIntegrationMethod, SparseADJacobian] = {}

        # Symbolic model pieces are fetched once and reused by build_jit_ad.
        self.state_vars = self.problem.get_state_vars()
        self.algebraic_vars = self.problem.get_algebraic_vars()
        self.state_eqs = self.problem.get_state_eqs()
        self.algebraic_eqs = self.problem.get_algebraic_eqs()

        self._predictor = Predictor(n_states=self.problem.get_states_number())

        # Parameter storage: a fixed static part and a full [runtime | static] buffer.
        self._runtime_param_count: int = self.problem.get_variable_parameter_number()
        self._static_parameter_buffer: Vec = np.asarray(
            [float(constant.value) for constant in self.problem.get_parameters_values()],
            dtype=np.float64,
        )
        self._full_parameter_buffer: Vec = np.zeros(
            self._runtime_param_count + self._static_parameter_buffer.shape[0],
            dtype=np.float64,
        )

        # Work buffers reused by every Newton iteration.
        self._residual_buffer: Vec = np.zeros(n_vars, dtype=np.float64)
        self._trial_state_buffer: Vec = np.zeros(n_vars, dtype=np.float64)
        self._trial_residual_buffer: Vec = np.zeros(n_vars, dtype=np.float64)

        self._backend_build_stats: Dict[str, float] = {
            'residual_compile_s': 0.0,
            'residual_batches': 0.0,
            'jacobian_compile_s': 0.0,
            'jacobian_builds': 0.0,
            'total_s': 0.0,
        }
        self._last_runtime_stats: Dict[str, float] = {}
        self._last_sim_loop_time: float = 0.0

    def build_jit_ad(self, only_jacobian: bool = False) -> None:
        """
        Compiles the residual kernel using batching and prepares the Sparse AD Jacobian.

        Nothing is done when the residual kernels of the current method already
        exist and a full build was requested. Build timings are accumulated in
        the backend build statistics.

        :param only_jacobian: If True, skips residual compilation (used by Vectorized Engine fallback).
        :type only_jacobian: bool
        :rtype: None
        """
        if not only_jacobian and self.method in self.jit_kernels_ad:
            return

        t_total_start: float = time.perf_counter()
        stats = self._backend_build_stats

        if self.verbose:
            print(
                f"--- [JIT-AD] Compiling {'Jacobian ONLY' if only_jacobian else 'Implicit Kernel + AD Jacobian'} ({self.method}) ---")

        # Unknowns are ordered [states | algebraics]; parameters [variable | constant].
        sorted_vars: List[Any] = [*self.state_vars, *self.algebraic_vars]
        all_params: List[Any] = [
            *self.problem.get_variable_parameters(),
            *self.problem.get_constant_parameters(),
        ]

        # Each state equation becomes "d_<state> - rhs"; the algebraic equations follow.
        equations: List[Expr] = []
        for i, rhs in enumerate(self.state_eqs):
            sv = self.state_vars[i]
            d_term = Var(name=f"d_{sv.name}", base_var=sv)
            equations.append(d_term - rhs)
        equations.extend(self.algebraic_eqs)

        # --- PART 1: RESIDUALS (Batched) ---
        if not only_jacobian and self.method not in self.jit_kernels_ad:
            compiler = EquationCompiler(variables=sorted_vars, parameters=all_params, method=self.method)
            BATCH_SIZE = 100
            n_eqs = len(equations)
            batched_kernels: List[Callable] = []

            if self.verbose:
                print(f" [JIT-AD] Splitting {n_eqs} residual eqs into {int(np.ceil(n_eqs / BATCH_SIZE))} batches...")

            t0 = time.perf_counter()
            for i in range(0, n_eqs, BATCH_SIZE):
                chunk = equations[i: i + BATCH_SIZE]
                fname = f"jit_step_{self.method}_ad_part_{i}"
                py_func = compiler.compile(chunk, func_name=fname, use_cse=True, offset=i, inplace=True)
                batched_kernels.append(_safe_njit(py_func, cache=True, fastmath=True))

            if self.verbose:
                print(f" [JIT-AD] Residual Batches compiled in {time.perf_counter() - t0:.4f}s")

            self.jit_kernels_ad[self.method] = batched_kernels
            residual_elapsed_s: float = time.perf_counter() - t0
            stats['residual_compile_s'] += residual_elapsed_s
            stats['residual_batches'] += float(len(batched_kernels))
            stats['total_s'] += residual_elapsed_s

        # --- PART 2: JACOBIAN (O(N) setup, needed for normal and Vectorized fallback) ---
        if self.method not in self.jit_jacobian_ad:
            jacobian_t0: float = time.perf_counter()
            self.jit_jacobian_ad[self.method] = SparseADJacobian(
                equations=equations,
                variables=sorted_vars,
                parameters=all_params,
                method=self.method,
                use_cse=True,
            )
            jacobian_elapsed_s: float = time.perf_counter() - jacobian_t0
            stats['jacobian_compile_s'] += jacobian_elapsed_s
            stats['jacobian_builds'] += 1.0
            stats['total_s'] += jacobian_elapsed_s

        # The wall time of this call is a lower bound for the recorded total.
        total_elapsed_s: float = time.perf_counter() - t_total_start
        stats['total_s'] = max(stats['total_s'], total_elapsed_s)

    def simulate(self,
                 x0: Union[Vec | None] = None,
                 dx0: Union[Vec | None] = None,
                 params0: Union[Vec | None] = None,
                 boundary_updater: EmtBoundaryUpdateProtocol | None = None) -> Tuple[Vec, Mat, Mat, bool, bool]:
        """
        Main JIT simulation loop using the Automatic Differentiation (AD) backend.

        Each macro step of the time grid is covered by one or more local substeps;
        a substep ends early when the boundary updater reports a forced event time
        inside the macro step. Every substep is solved with Newton iterations.

        :param x0: Initial state vector; taken from the problem when None.
        :param dx0: Initial derivative vector; taken from the problem when None.
        :param params0: Initial event parameters; taken from the problem when None.
        :param boundary_updater: Optional override for the problem boundary updater.
        :return: Time vector, state trajectory, derivative trajectory, a flag that is False when
            the very first substep did not converge, and a flag that is False when any substep
            did not converge.
        :raises NotImplementedError: When an event-aligned substep is required with DaeBDF2.
        """
        t_start = time.time()
        method = self.method
        problem = self.problem
        converged: bool = True
        well_initialized: bool = True

        if method not in self.jit_kernels_ad:
            self.build_jit_ad()
        kernel_list = self.jit_kernels_ad[method]
        current_jacobian = self.jit_jacobian_ad[method]

        n_eqs = problem.get_all_vars_number()
        n_states = problem.get_states_number()
        use_dense_solver = (n_eqs <= self.dense_threshold)

        t = np.linspace(self.t0, self.t_end, self.steps + 1, dtype=np.float64)
        self.t = t

        if x0 is None:
            x0 = problem.get_x0()
        if dx0 is None:
            dx0 = problem.get_dx0()

        active_boundary_updater = resolve_solver_boundary_updater(problem, boundary_updater, float(self.t0))

        self.y[0, :] = x0
        self.dy[0, :] = dx0

        # Independent working copies so the caller's arrays are never modified.
        x_prev = x0.copy()
        dx_prev = dx0.copy()
        x_prev2 = x0.copy()
        x_iter = x0.copy()

        trial_x = self._trial_state_buffer
        trial_res = self._trial_residual_buffer
        static_vals = self._static_parameter_buffer
        n_event_params = self._runtime_param_count
        full_params = self._full_parameter_buffer
        residual_buffer = self._residual_buffer

        if params0 is None:
            ev_params = problem.event_params_values.copy()
        else:
            ev_params = np.array(params0, dtype=np.float64, copy=True)

        # A problem-owned boundary updater is applied once at t0 before stepping.
        should_initialize_at_t0: bool = is_problem_owned_boundary_updater(problem, active_boundary_updater)
        if active_boundary_updater is not None and should_initialize_at_t0:
            runtime_params0 = problem.def_event_params_fn(ev_params, float(self.t0))
            fill_full_parameter_buffer(runtime_params0, static_vals, full_params)
            active_boundary_updater.update(float(self.t0), x_prev, full_params)
            ev_params[:] = full_params[:n_event_params]

        trace_collector = problem.get_newton_trace_collector()
        diag_cfg = self._newton_diag_config
        diagnostics_enabled: bool = (
            trace_collector is not None
            or diag_cfg.compute_dense_cond
            or diag_cfg.enable_fallback
            or diag_cfg.enable_index1_check
            or diag_cfg.enable_backtracking
        )

        # The wrapped linear solvers are only built when some diagnostic is on.
        dense_solve = None
        sparse_solve = None
        if diagnostics_enabled:
            dense_solve = with_newton_diagnostics(
                np.linalg.solve,
                fallback_solve=dense_lstsq_fallback,
                collector=trace_collector,
                config=diag_cfg,
                solver_name="dense",
            )
            sparse_solve = with_newton_diagnostics(
                sp.linalg.spsolve,
                fallback_solve=sparse_lsqr_fallback,
                collector=trace_collector,
                config=diag_cfg,
                solver_name="sparse",
            )

        backtracking_residual_evaluator = AdBacktrackingResidualEvaluator()

        if self.verbose:
            print(f"-> Starting JIT (AD) Simulation ({method}, {self.steps} steps)...")

        loop_start: float = time.perf_counter()
        total_newton_iterations: int = 0
        jacobian_evaluation_count: int = 0
        aligned_substep_count: int = 0

        for i in range(self.steps):
            t_step_target = float(t[i + 1])
            t_local_prev = float(t[i])
            is_first_local_step = True

            while t_local_prev < (t_step_target - 1e-15):
                # Choose the end of this substep: the macro-step target, or an
                # earlier forced event time reported by the boundary updater.
                forced_event_time = get_solver_forced_event_time(
                    active_boundary_updater,
                    float(t_local_prev),
                    float(t_step_target),
                )
                if forced_event_time is None:
                    t_curr = t_step_target
                else:
                    t_curr = min(float(forced_event_time), t_step_target)
                # A forced time that does not advance falls back to the target.
                if t_curr <= t_local_prev + 1e-15:
                    t_curr = t_step_target

                is_aligned_substep = t_curr < (t_step_target - 1e-15)
                if is_aligned_substep:
                    aligned_substep_count += 1
                    if method == DynamicIntegrationMethod.DaeBDF2:
                        raise NotImplementedError(
                            "force_step_alignment is not implemented for DaeBDF2 in JitAdSolver."
                        )

                h_eff = float(t_curr - t_local_prev)

                # Refresh the parameters for the end of the substep.
                runtime_params = problem.def_event_params_fn(ev_params, float(t_curr))
                fill_full_parameter_buffer(runtime_params, static_vals, full_params)
                if active_boundary_updater is not None:
                    active_boundary_updater.update(t_curr, x_prev, full_params)
                ev_params[:] = full_params[:n_event_params]

                # Initial Newton guess; the predictor is only used on the very
                # first substep of a trapezoidal run.
                x_iter[:] = x_prev
                if i == 0 and is_first_local_step and method == DynamicIntegrationMethod.DaeTrapezoidal:
                    self._predictor.predict(
                        x_iter=x_iter,
                        x_prev=x_prev,
                        dx_prev=dx_prev,
                        h=h_eff,
                        pred_method=self.pred_method,
                    )

                substep_converged: bool = False
                for k in range(self.newton_max_iter):
                    total_newton_iterations += 1
                    ctx = NewtonSolveContext(
                        t=float(t_curr),
                        step_idx=int(i),
                        newton_iter=int(k),
                        phase="jit_ad",
                        method=str(method),
                    )

                    # One residual buffer is reused by all Newton iterations.
                    res_norm_inf = evaluate_batched_residual(
                        kernel_list=kernel_list,
                        x_iter=x_iter,
                        full_params=full_params,
                        x_prev=x_prev,
                        dx_prev=dx_prev,
                        h_eff=h_eff,
                        x_prev2=x_prev2,
                        residual_out=residual_buffer,
                    )
                    if res_norm_inf < 1e-5:
                        substep_converged = True
                        break

                    # Evaluate AD Jacobian
                    J = current_jacobian(
                        states=x_iter,
                        params=full_params,
                        history=x_prev,
                        d_history=dx_prev,
                        h=h_eff,
                        history2=x_prev2,
                    )
                    jacobian_evaluation_count += 1

                    if diagnostics_enabled:
                        maybe_check_index1(J, n_states, ctx=ctx, config=diag_cfg)

                    # Newton correction: solve J * delta = -residual.
                    rhs = -residual_buffer
                    if use_dense_solver:
                        dense_J = J.toarray()
                        if diagnostics_enabled:
                            delta = dense_solve(dense_J, rhs, ctx)
                        else:
                            delta = np.linalg.solve(dense_J, rhs)
                    elif diagnostics_enabled:
                        delta = sparse_solve(J, rhs, ctx)
                    else:
                        delta = sp.linalg.spsolve(J, rhs)

                    if diag_cfg.enable_backtracking:
                        backtracking_residual_evaluator.set_context(
                            kernel_list=kernel_list,
                            full_params=full_params,
                            x_prev=x_prev,
                            dx_prev=dx_prev,
                            h_eff=h_eff,
                            x_prev2=x_prev2,
                        )
                        maybe_apply_backtracking(
                            x_iter,
                            delta,
                            res_norm_inf,
                            trial_x,
                            trial_res,
                            evaluate_residual=backtracking_residual_evaluator.evaluate,
                            config=diag_cfg,
                        )
                    else:
                        x_iter += delta

                if not substep_converged:
                    converged = False
                    if i == 0 and is_first_local_step:
                        well_initialized = False

                # Derivative history update (none is stored for DaeBDF2).
                if method == DynamicIntegrationMethod.DaeTrapezoidal:
                    dx_prev[:n_states] = (
                        (2.0 / h_eff) * (x_iter[:n_states] - x_prev[:n_states]) - dx_prev[:n_states]
                    )
                elif method != DynamicIntegrationMethod.DaeBDF2:
                    dx_prev[:n_states] = (
                        (x_iter[:n_states] - x_prev[:n_states]) / h_eff
                    )

                # Shift the state history: the old accepted state becomes the
                # secondary history before the new one is accepted.
                x_prev2 = x_prev.copy()
                x_prev[:] = x_iter
                t_local_prev = t_curr
                is_first_local_step = False

            self.y[i + 1, :] = x_prev
            self.dy[i + 1, :] = dx_prev

        self._last_sim_loop_time = time.perf_counter() - loop_start
        self._last_runtime_stats = {
            'sim_loop_s': self._last_sim_loop_time,
            'total_newton_iterations': float(total_newton_iterations),
            'jacobian_evaluations': float(jacobian_evaluation_count),
            'aligned_substeps': float(aligned_substep_count),
            'macro_steps': float(self.steps),
        }

        if self.verbose:
            print(f"JIT (AD) Finished: {time.time() - t_start:.4f}s")

        return self.t, self.y, self.dy, well_initialized, converged

    def get_backend_build_stats(self) -> Dict[str, float]:
        """
        Return a copy of the setup timings collected during backend compilation.

        :return: Backend build statistics.
        :rtype: Dict[str, float]
        """
        return dict(self._backend_build_stats)

    def get_last_runtime_stats(self) -> Dict[str, float]:
        """
        Return a copy of the runtime statistics collected during the latest simulation.

        :return: Runtime statistics.
        :rtype: Dict[str, float]
        """
        return dict(self._last_runtime_stats)

    def get_last_sim_loop_time(self) -> float:
        """
        Return the latest integration-loop wall time.

        :return: Latest loop time in seconds.
        :rtype: float
        """
        return self._last_sim_loop_time