# Source code for the VeraGridEngine.Simulations.EMT.JMARTI_Sim.jmarti_runtime module.

# SPDX-License-Identifier: MPL-2.0

from __future__ import annotations

from typing import Any, List, Tuple

import numpy as np

from VeraGridEngine.Devices.Branches.line import Line
from VeraGridEngine.Simulations.EMT.JMARTI_Sim.jmarti_fit_bundle import JMartiFitBundle
from VeraGridEngine.Simulations.EMT.JMARTI_Sim.jmarti_runtime_data import JMartiModeRuntimeData
from VeraGridEngine.Simulations.EMT.JMARTI_Sim.jmarti_runtime_data import JMartiRuntimeData
from VeraGridEngine.Simulations.EMT.JMARTI_Sim.jmarti_runtime_data import build_jmarti_runtime_data
from VeraGridEngine.Utils.Symbolic.block import Block, Var
from VeraGridEngine.enumerations import VarPowerFlowReferenceType


class JMartiHistoryRuntime:
    """
    Runtime companion for a JMARTI line in reduced active-phase space.

    The first implementation uses a constant modal basis and one scalar fit per mode.
    The runtime keeps all convolution-like states internally and exposes only
    phase-domain Norton history current injections back to the EMT solver.

    On each side the terminal currents follow the Norton form
    ``i = direct_yc_phase @ v + Ih`` (see :meth:`get_nodal_injections`), where the
    ``Ih`` entries are event parameters rewritten by :meth:`update_history`.
    """

    __slots__ = (
        'line',
        'block',
        'h',
        'runtime_data',
        'ph_labels',
        'ph_mask',
        'phase_idx',
        'active_ph',
        'm',
        'Ih_f',
        'Ih_t',
        'direct_yc_modal',
        'direct_zc_modal',
        'direct_yc_phase',
        'max_delay_steps',
        'buffer_size',
        'buf_a_f',
        'buf_a_t',
        'yc_state_from',
        'yc_state_to',
        'hres_state_from',
        'hres_state_to',
        'yc_order_counts',
        'hres_order_counts',
        'v_f_vars',
        'v_t_vars',
        'idx_vf',
        'idx_vt',
        'idx_p_hf',
        'idx_p_ht',
    )

    def __init__(self, line: Any, line_block: Block, h: float) -> None:
        """
        Build one JMARTI runtime companion.

        :param line: Line-like branch device carrying one fit bundle or runtime data.
        :param line_block: Symbolic JMARTI line block.
        :param h: EMT time step in seconds.
        :return: None.
        :raises ValueError: Propagated when no runtime data / fit bundle can be
            resolved or when a history variable is missing from the block.
        """
        self.line = line
        self.block = line_block
        self.h = float(h)
        self.runtime_data = _resolve_jmarti_runtime_data(line=line, line_block=line_block, time_step_s=self.h)

        # Fixed NABC ordering; the mask flags which phases the line actually carries.
        self.ph_labels: List[str] = ["N", "A", "B", "C"]
        self.ph_mask: np.ndarray = np.asarray(
            [bool(line.ys.phN), bool(line.ys.phA), bool(line.ys.phB), bool(line.ys.phC)],
            dtype=bool,
        )
        self.phase_idx: List[int] = [int(i) for i in np.where(self.ph_mask)[0]]
        self.active_ph: List[str] = [self.ph_labels[i] for i in self.phase_idx]
        self.m: int = len(self.active_ph)

        self.Ih_f: List[Var] = _extract_jmarti_history_vars(
            line_block=self.block,
            prefix=f"Ih_f_{line.name}_",
            active_ph=self.active_ph,
            line_name=line.name,
        )
        self.Ih_t: List[Var] = _extract_jmarti_history_vars(
            line_block=self.block,
            prefix=f"Ih_t_{line.name}_",
            active_ph=self.active_ph,
            line_name=line.name,
        )

        # Placeholders; _initialize_jmarti_runtime_storage reassigns / fills them
        # from the resolved runtime data at the end of this constructor.
        self.direct_yc_modal: np.ndarray = np.zeros(self.m, dtype=np.complex128)
        self.direct_zc_modal: np.ndarray = np.zeros(self.m, dtype=np.complex128)
        self.direct_yc_phase: np.ndarray = np.zeros((self.m, self.m), dtype=np.complex128)
        self.max_delay_steps: int = 0
        self.buffer_size: int = 0
        self.buf_a_f: np.ndarray = np.zeros((1, self.m), dtype=np.complex128)
        self.buf_a_t: np.ndarray = np.zeros((1, self.m), dtype=np.complex128)
        self.yc_state_from: np.ndarray = np.zeros((self.m, 1), dtype=np.complex128)
        self.yc_state_to: np.ndarray = np.zeros((self.m, 1), dtype=np.complex128)
        self.hres_state_from: np.ndarray = np.zeros((self.m, 1), dtype=np.complex128)
        self.hres_state_to: np.ndarray = np.zeros((self.m, 1), dtype=np.complex128)
        self.yc_order_counts: np.ndarray = np.zeros(self.m, dtype=np.int64)
        self.hres_order_counts: np.ndarray = np.zeros(self.m, dtype=np.int64)

        # Filled later by bind_terminals(...) and setup_indices(...).
        self.v_f_vars = None
        self.v_t_vars = None
        self.idx_vf = None
        self.idx_vt = None
        self.idx_p_hf = None
        self.idx_p_ht = None

        _initialize_jmarti_runtime_storage(self)

    def bind_terminals(self, v_f_vars: List[Any], v_t_vars: List[Any]) -> None:
        """
        Bind the bus terminal voltage variables for the active phases only.

        A 4-entry list is read in full NABC order (indexed by the phase's NABC
        position); any other length is read positionally in active-phase order.

        :param v_f_vars: Full from-side bus voltage variable list in NABC order.
        :param v_t_vars: Full to-side bus voltage variable list in NABC order.
        :return: None.
        :raises ValueError: If an active phase maps to a ``None`` voltage variable.
        """
        from_side_full_layout: bool = len(v_f_vars) == 4
        to_side_full_layout: bool = len(v_t_vars) == 4

        self.v_f_vars = list()
        self.v_t_vars = list()

        for reduced_index, (full_index, phase_name) in enumerate(zip(self.phase_idx, self.active_ph)):
            vf_var = v_f_vars[full_index if from_side_full_layout else reduced_index]
            vt_var = v_t_vars[full_index if to_side_full_layout else reduced_index]

            if vf_var is None:
                raise ValueError(
                    f"JMARTI line '{self.line.name}' has active phase '{phase_name}' "
                    f"but from-bus '{self.line.bus_from.name}' does not provide the corresponding EMT voltage variable."
                )
            if vt_var is None:
                raise ValueError(
                    f"JMARTI line '{self.line.name}' has active phase '{phase_name}' "
                    f"but to-bus '{self.line.bus_to.name}' does not provide the corresponding EMT voltage variable."
                )

            self.v_f_vars.append(vf_var)
            self.v_t_vars.append(vt_var)

    def get_nodal_injections(self) -> Tuple[List[Any], List[Any]]:
        """
        Return the phase-domain Norton injections seen by the EMT nodal solver.

        For each active phase row ``r`` the expression is
        ``Ih[r] + sum_c direct_yc_phase[r, c] * v[c]``, built term by term so it
        also works with symbolic variables.

        :return: Tuple ``(i_from_full, i_to_full)`` in fixed NABC order; entries of
            inactive phases are ``None``.
        :raises RuntimeError: If :meth:`bind_terminals` has not been called.
        """
        if self.v_f_vars is None or self.v_t_vars is None:
            raise RuntimeError("bind_terminals(...) must be called before get_nodal_injections().")

        i_f_full: List[Any] = [None, None, None, None]
        i_t_full: List[Any] = [None, None, None, None]

        for row_index in range(self.m):
            expr_f: Any = self.Ih_f[row_index]
            expr_t: Any = self.Ih_t[row_index]
            for col_index in range(self.m):
                conductance = self.direct_yc_phase[row_index, col_index]
                expr_f = expr_f + conductance * self.v_f_vars[col_index]
                expr_t = expr_t + conductance * self.v_t_vars[col_index]

            full_index: int = self.phase_idx[row_index]
            i_f_full[full_index] = expr_f
            i_t_full[full_index] = expr_t

        return i_f_full, i_t_full

    def setup_indices(self, uid2idx_vars: dict, uid2idx_event_params: dict, params_offset: int = 0) -> None:
        """
        Bind solver indices used during the history update.

        :param uid2idx_vars: Variable index map.
        :param uid2idx_event_params: Event-parameter index map.
        :param params_offset: Runtime-parameter offset added to every history index.
        :return: None.
        :raises RuntimeError: If :meth:`bind_terminals` has not been called.
        """
        # Same precondition as get_nodal_injections; without it the comprehensions
        # below would fail with an opaque "NoneType is not iterable".
        if self.v_f_vars is None or self.v_t_vars is None:
            raise RuntimeError("bind_terminals(...) must be called before setup_indices().")

        self.idx_vf = [uid2idx_vars[v.uid] for v in self.v_f_vars]
        self.idx_vt = [uid2idx_vars[v.uid] for v in self.v_t_vars]
        self.idx_p_hf = [uid2idx_event_params[p.uid] + params_offset for p in self.Ih_f]
        self.idx_p_ht = [uid2idx_event_params[p.uid] + params_offset for p in self.Ih_t]

    def get_mode_count(self) -> int:
        """
        Return the number of modal channels carried by the runtime.

        :return: Number of modal channels (equal to the number of active phases).
        """
        return self.m

    def _fill_incident_wave_buffers(self, a_f_modal: np.ndarray, a_t_modal: np.ndarray) -> None:
        """
        Write the same modal wave vector into every slot of both ring buffers.

        :param a_f_modal: From-side modal wave vector.
        :param a_t_modal: To-side modal wave vector.
        :return: None.
        """
        self.buf_a_f[: self.buffer_size, :] = a_f_modal
        self.buf_a_t[: self.buffer_size, :] = a_t_modal

    def _seed_modal_states(self,
                           v_f0_modal: np.ndarray,
                           v_t0_modal: np.ndarray,
                           a_f0_modal: np.ndarray,
                           a_t0_modal: np.ndarray,
                           lambda_step: complex | None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Seed the propagation (hres) and Yc state rows of every mode.

        Each side's propagation states are driven by the remote side's wave, and
        its Yc states by ``v_local - 2 * hres_output``, matching the update order
        of ``_update_jmarti_mode_history``.

        :param v_f0_modal: From-side modal voltages.
        :param v_t0_modal: To-side modal voltages.
        :param a_f0_modal: From-side modal waves.
        :param a_t0_modal: To-side modal waves.
        :param lambda_step: ``None`` for the constant-input equilibrium, or the
            one-step phase advance ``exp(j*w*h)`` for the periodic solution.
        :return: Tuple ``(ih_f_modal, ih_t_modal)`` implied by the seeded states.
        """
        def build_row(alpha: np.ndarray, beta: np.ndarray, input_value: complex) -> np.ndarray:
            if lambda_step is None:
                return _build_jmarti_steady_state_state_row(alpha=alpha, beta=beta, input_value=input_value)
            return _build_jmarti_periodic_state_row(
                alpha=alpha, beta=beta, input_value=input_value, lambda_step=lambda_step,
            )

        self.yc_state_from.fill(0.0)
        self.yc_state_to.fill(0.0)
        self.hres_state_from.fill(0.0)
        self.hres_state_to.fill(0.0)

        ih_f_modal: np.ndarray = np.zeros(self.m, dtype=np.complex128)
        ih_t_modal: np.ndarray = np.zeros(self.m, dtype=np.complex128)
        mode_data_list = self.runtime_data.get_mode_data()

        for mode_index in range(self.m):
            mode_data: JMartiModeRuntimeData = mode_data_list[mode_index]
            yc_order_count: int = int(self.yc_order_counts[mode_index])
            hres_order_count: int = int(self.hres_order_counts[mode_index])

            if hres_order_count > 0:
                self.hres_state_from[mode_index, :hres_order_count] = build_row(
                    mode_data.get_hres_alpha(), mode_data.get_hres_beta(), a_t0_modal[mode_index],
                )
                self.hres_state_to[mode_index, :hres_order_count] = build_row(
                    mode_data.get_hres_alpha(), mode_data.get_hres_beta(), a_f0_modal[mode_index],
                )

            hres_constant = mode_data.get_hres_constant_term()
            hres_output_from: complex = (
                hres_constant * a_t0_modal[mode_index]
                + _sum_state_row(self.hres_state_from, mode_index, hres_order_count)
            )
            hres_output_to: complex = (
                hres_constant * a_f0_modal[mode_index]
                + _sum_state_row(self.hres_state_to, mode_index, hres_order_count)
            )

            yc_input_from: complex = v_f0_modal[mode_index] - 2.0 * hres_output_from
            yc_input_to: complex = v_t0_modal[mode_index] - 2.0 * hres_output_to

            if yc_order_count > 0:
                self.yc_state_from[mode_index, :yc_order_count] = build_row(
                    mode_data.get_yc_alpha(), mode_data.get_yc_beta(), yc_input_from,
                )
                self.yc_state_to[mode_index, :yc_order_count] = build_row(
                    mode_data.get_yc_alpha(), mode_data.get_yc_beta(), yc_input_to,
                )

            direct_yc = self.direct_yc_modal[mode_index]
            ih_f_modal[mode_index] = (
                _sum_state_row(self.yc_state_from, mode_index, yc_order_count) - 2.0 * direct_yc * hres_output_from
            )
            ih_t_modal[mode_index] = (
                _sum_state_row(self.yc_state_to, mode_index, yc_order_count) - 2.0 * direct_yc * hres_output_to
            )

        return ih_f_modal, ih_t_modal

    def initialize_from_initial_point(self,
                                      v_f0_red: np.ndarray,
                                      v_t0_red: np.ndarray,
                                      i_f0_red: np.ndarray,
                                      i_t0_red: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """
        Initialize buffers and history parameters from one steady-state point.

        Every ring-buffer slot receives the initial modal wave and every state row
        its constant-input equilibrium ``beta * u / (1 - alpha)``.

        :param v_f0_red: From-side active-phase voltages in phase coordinates.
        :param v_t0_red: To-side active-phase voltages in phase coordinates.
        :param i_f0_red: From-side active-phase currents in phase coordinates.
        :param i_t0_red: To-side active-phase currents in phase coordinates.
        :return: Tuple ``(ih_f_phase, ih_t_phase)`` (complex, phase coordinates)
            used to seed event parameters.
        """
        modal_transform_inv: np.ndarray = self.runtime_data.get_modal_transform_inv()
        v_f0_modal: np.ndarray = modal_transform_inv @ np.asarray(v_f0_red, dtype=np.complex128)
        v_t0_modal: np.ndarray = modal_transform_inv @ np.asarray(v_t0_red, dtype=np.complex128)
        i_f0_modal: np.ndarray = modal_transform_inv @ np.asarray(i_f0_red, dtype=np.complex128)
        i_t0_modal: np.ndarray = modal_transform_inv @ np.asarray(i_t0_red, dtype=np.complex128)

        # Modal wave a = (v + Zc_direct * i) / 2, the quantity stored in the ring buffers.
        a_f0_modal: np.ndarray = 0.5 * (v_f0_modal + self.direct_zc_modal * i_f0_modal)
        a_t0_modal: np.ndarray = 0.5 * (v_t0_modal + self.direct_zc_modal * i_t0_modal)

        self._fill_incident_wave_buffers(a_f0_modal, a_t0_modal)
        ih_f_modal, ih_t_modal = self._seed_modal_states(
            v_f0_modal=v_f0_modal,
            v_t0_modal=v_t0_modal,
            a_f0_modal=a_f0_modal,
            a_t0_modal=a_t0_modal,
            lambda_step=None,
        )

        modal_transform: np.ndarray = self.runtime_data.get_modal_transform()
        return modal_transform @ ih_f_modal, modal_transform @ ih_t_modal

    def initialize_from_fundamental_phasors(self,
                                            v_f0_phasor_red: np.ndarray,
                                            v_t0_phasor_red: np.ndarray,
                                            i_f0_phasor_red: np.ndarray,
                                            i_t0_phasor_red: np.ndarray,
                                            system_frequency_hz: float) -> tuple[np.ndarray, np.ndarray]:
        """
        Initialize one JMARTI runtime from one sinusoidal fundamental operating point.

        The runtime states represent first-order discrete filters, so their
        consistent steady state under a balanced sinusoidal operating point is not
        the constant-input equilibrium. This initializer seeds the internal modal
        states with the periodic discrete-time solution at the requested
        fundamental frequency.

        NOTE(review): the ring buffers are filled with the constant ``n=0`` wave and
        the propagation states are driven by the undelayed remote wave; the delay
        phase shift is not applied here. Confirm whether this initial mismatch is
        acceptable or handled elsewhere.

        :param v_f0_phasor_red: From-side phase phasors in RMS complex form.
        :param v_t0_phasor_red: To-side phase phasors in RMS complex form.
        :param i_f0_phasor_red: From-side branch-current phasors in RMS complex form.
        :param i_t0_phasor_red: To-side branch-current phasors in RMS complex form.
        :param system_frequency_hz: Fundamental system frequency in Hz.
        :return: Tuple ``(ih_f_phase, ih_t_phase)`` used to seed event parameters.
        """
        angular_frequency_rad_per_s: float = 2.0 * np.pi * float(system_frequency_hz)
        lambda_step: complex = complex(np.exp(1j * angular_frequency_rad_per_s * self.h))

        v_f0_phasor: np.ndarray = np.asarray(v_f0_phasor_red, dtype=np.complex128)
        v_t0_phasor: np.ndarray = np.asarray(v_t0_phasor_red, dtype=np.complex128)
        i_f0_phasor: np.ndarray = np.asarray(i_f0_phasor_red, dtype=np.complex128)
        i_t0_phasor: np.ndarray = np.asarray(i_t0_phasor_red, dtype=np.complex128)

        # Sine convention: x[n] = sqrt(2) * Im(X * lambda^n) = Re(z0 * lambda^n)
        # with the analytic n=0 sample z0 = -j * sqrt(2) * X.
        v_f0_analytic: np.ndarray = -1j * np.sqrt(2.0) * v_f0_phasor
        v_t0_analytic: np.ndarray = -1j * np.sqrt(2.0) * v_t0_phasor
        i_f0_analytic: np.ndarray = -1j * np.sqrt(2.0) * i_f0_phasor
        i_t0_analytic: np.ndarray = -1j * np.sqrt(2.0) * i_t0_phasor

        v_f0_sample: np.ndarray = np.sqrt(2.0) * np.imag(v_f0_phasor)
        v_t0_sample: np.ndarray = np.sqrt(2.0) * np.imag(v_t0_phasor)
        i_f0_sample: np.ndarray = np.sqrt(2.0) * np.imag(i_f0_phasor)
        i_t0_sample: np.ndarray = np.sqrt(2.0) * np.imag(i_t0_phasor)

        modal_transform_inv: np.ndarray = self.runtime_data.get_modal_transform_inv()
        v_f0_modal: np.ndarray = modal_transform_inv @ v_f0_analytic
        v_t0_modal: np.ndarray = modal_transform_inv @ v_t0_analytic
        i_f0_modal: np.ndarray = modal_transform_inv @ i_f0_analytic
        i_t0_modal: np.ndarray = modal_transform_inv @ i_t0_analytic

        a_f0_modal: np.ndarray = 0.5 * (v_f0_modal + self.direct_zc_modal * i_f0_modal)
        a_t0_modal: np.ndarray = 0.5 * (v_t0_modal + self.direct_zc_modal * i_t0_modal)

        self._fill_incident_wave_buffers(a_f0_modal, a_t0_modal)
        # Only the seeded states are needed here; the modal history they imply is discarded.
        self._seed_modal_states(
            v_f0_modal=v_f0_modal,
            v_t0_modal=v_t0_modal,
            a_f0_modal=a_f0_modal,
            a_t0_modal=a_t0_modal,
            lambda_step=lambda_step,
        )

        # The event parameters are seeded from the instantaneous t=0 samples with
        # Ih = i - direct_yc_phase @ v, i.e. the Norton form of get_nodal_injections.
        ih_f_phase: np.ndarray = np.asarray(i_f0_sample - self.direct_yc_phase @ v_f0_sample, dtype=np.complex128)
        ih_t_phase: np.ndarray = np.asarray(i_t0_sample - self.direct_yc_phase @ v_t0_sample, dtype=np.complex128)
        return ih_f_phase, ih_t_phase

    def update_history(self, step_counter: int, x_prev: np.ndarray, full_params: np.ndarray) -> None:
        """
        Update the retained JMARTI history injections after one accepted EMT step.

        The terminal currents are reconstructed from the Norton form, converted to
        modal waves, pushed into the ring buffers, and every mode is advanced by
        ``_update_jmarti_mode_history``. The resulting next-step history currents
        are written (real part only) back into ``full_params``.

        :param step_counter: Current accepted step counter.
        :param x_prev: Previous accepted state vector.
        :param full_params: Flat runtime parameter vector (modified in place).
        :return: None.
        """
        modal_transform_inv: np.ndarray = self.runtime_data.get_modal_transform_inv()

        # Negative solver indices read as 0.0.
        v_f_now_phase: np.ndarray = np.asarray(
            [x_prev[index] if index >= 0 else 0.0 for index in self.idx_vf], dtype=np.complex128,
        )
        v_t_now_phase: np.ndarray = np.asarray(
            [x_prev[index] if index >= 0 else 0.0 for index in self.idx_vt], dtype=np.complex128,
        )
        ih_f_now_phase: np.ndarray = np.asarray([full_params[index] for index in self.idx_p_hf], dtype=np.complex128)
        ih_t_now_phase: np.ndarray = np.asarray([full_params[index] for index in self.idx_p_ht], dtype=np.complex128)

        # Branch currents from the Norton form i = G v + Ih.
        i_f_now_phase: np.ndarray = self.direct_yc_phase @ v_f_now_phase + ih_f_now_phase
        i_t_now_phase: np.ndarray = self.direct_yc_phase @ v_t_now_phase + ih_t_now_phase

        v_f_now_modal: np.ndarray = modal_transform_inv @ v_f_now_phase
        v_t_now_modal: np.ndarray = modal_transform_inv @ v_t_now_phase
        i_f_now_modal: np.ndarray = modal_transform_inv @ i_f_now_phase
        i_t_now_modal: np.ndarray = modal_transform_inv @ i_t_now_phase

        a_f_now_modal: np.ndarray = 0.5 * (v_f_now_modal + self.direct_zc_modal * i_f_now_modal)
        a_t_now_modal: np.ndarray = 0.5 * (v_t_now_modal + self.direct_zc_modal * i_t_now_modal)

        current_buffer_index: int = step_counter % self.buffer_size
        self.buf_a_f[current_buffer_index, :] = a_f_now_modal
        self.buf_a_t[current_buffer_index, :] = a_t_now_modal

        ih_f_next_modal: np.ndarray = np.zeros(self.m, dtype=np.complex128)
        ih_t_next_modal: np.ndarray = np.zeros(self.m, dtype=np.complex128)
        for mode_index in range(self.m):
            # Each side is driven by its local voltage and the remote side's waves.
            ih_f_next_modal[mode_index] = _update_jmarti_mode_history(
                runtime=self,
                mode_index=mode_index,
                step_counter=step_counter,
                local_voltage_modal=v_f_now_modal[mode_index],
                remote_incident_buffer=self.buf_a_t,
                yc_state=self.yc_state_from,
                hres_state=self.hres_state_from,
            )
            ih_t_next_modal[mode_index] = _update_jmarti_mode_history(
                runtime=self,
                mode_index=mode_index,
                step_counter=step_counter,
                local_voltage_modal=v_t_now_modal[mode_index],
                remote_incident_buffer=self.buf_a_f,
                yc_state=self.yc_state_to,
                hres_state=self.hres_state_to,
            )

        modal_transform: np.ndarray = self.runtime_data.get_modal_transform()
        ih_f_next_phase: np.ndarray = modal_transform @ ih_f_next_modal
        ih_t_next_phase: np.ndarray = modal_transform @ ih_t_next_modal

        for phase_index in range(self.m):
            full_params[self.idx_p_hf[phase_index]] = float(np.real(ih_f_next_phase[phase_index]))
            full_params[self.idx_p_ht[phase_index]] = float(np.real(ih_t_next_phase[phase_index]))
# Instance-__dict__ keys used by the block attach/lookup helpers of this module.
# Key under which an offline JMARTI fit bundle is stored on an EMT block.
JMARTI_BLOCK_FIT_BUNDLE_ATTR: str = "_jmarti_fit_bundle"
# Key under which discretized JMARTI runtime data is stored on an EMT block.
JMARTI_BLOCK_RUNTIME_DATA_ATTR: str = "_jmarti_runtime_data"
def set_jmarti_block_fit_bundle(block: Block, fit_bundle: JMartiFitBundle | None) -> None:
    """
    Attach, replace or detach the JMARTI fit bundle stored on one EMT block.

    The bundle is kept in the block's instance ``__dict__`` under
    ``JMARTI_BLOCK_FIT_BUNDLE_ATTR``. Passing ``None`` removes any attached
    bundle; removing a bundle that is not present is not an error.

    NOTE(review): runtime data already cached on the block is left untouched, and
    ``_resolve_jmarti_runtime_data`` prefers that cached data over the bundle.

    :param block: Target EMT block.
    :param fit_bundle: Offline fit bundle to attach, or ``None`` to detach.
    :return: None.
    """
    instance_attributes = block.__dict__
    if fit_bundle is not None:
        instance_attributes[JMARTI_BLOCK_FIT_BUNDLE_ATTR] = fit_bundle
        return
    instance_attributes.pop(JMARTI_BLOCK_FIT_BUNDLE_ATTR, None)
def get_jmarti_block_fit_bundle(block: Block | None) -> JMartiFitBundle | None:
    """
    Look up the JMARTI fit bundle stored on one EMT block.

    :param block: Candidate EMT block; anything that is not a ``Block`` yields ``None``.
    :return: The attached ``JMartiFitBundle``, or ``None`` when the block is not a
        ``Block``, has no entry, or the stored entry has another type.
    """
    if not isinstance(block, Block):
        return None
    candidate = block.__dict__.get(JMARTI_BLOCK_FIT_BUNDLE_ATTR, None)
    return candidate if isinstance(candidate, JMartiFitBundle) else None
def set_jmarti_block_runtime_data(block: Block, runtime_data: JMartiRuntimeData | None) -> None:
    """
    Attach, replace or detach the discretized JMARTI runtime data of one EMT block.

    The data is kept in the block's instance ``__dict__`` under
    ``JMARTI_BLOCK_RUNTIME_DATA_ATTR``. Passing ``None`` removes any attached
    data; removing data that is not present is not an error.

    :param block: Target EMT block.
    :param runtime_data: Runtime data to attach, or ``None`` to detach.
    :return: None.
    """
    instance_attributes = block.__dict__
    if runtime_data is not None:
        instance_attributes[JMARTI_BLOCK_RUNTIME_DATA_ATTR] = runtime_data
        return
    instance_attributes.pop(JMARTI_BLOCK_RUNTIME_DATA_ATTR, None)
def get_jmarti_block_runtime_data(block: Block | None) -> JMartiRuntimeData | None:
    """
    Look up the discretized JMARTI runtime data stored on one EMT block.

    :param block: Candidate EMT block; anything that is not a ``Block`` yields ``None``.
    :return: The attached ``JMartiRuntimeData``, or ``None`` when the block is not
        a ``Block``, has no entry, or the stored entry has another type.
    """
    if not isinstance(block, Block):
        return None
    candidate = block.__dict__.get(JMARTI_BLOCK_RUNTIME_DATA_ATTR, None)
    return candidate if isinstance(candidate, JMartiRuntimeData) else None
def _resolve_jmarti_runtime_data(line: Any, line_block: Block, time_step_s: float) -> JMartiRuntimeData: """ Resolve one discrete JMARTI runtime dataset from the line object. :param line: Line-like branch device. :param line_block: EMT block carrying the JMARTI configuration and optional fit. :param time_step_s: EMT time step in seconds. :return: Runtime-ready JMARTI dataset. :raises ValueError: If no JMARTI fit information is attached to the line. """ if isinstance(line, Line) and isinstance(line_block, Block): attached_runtime_data: JMartiRuntimeData | None = get_jmarti_block_runtime_data(line_block) if attached_runtime_data is not None: return attached_runtime_data else: attached_fit_bundle: JMartiFitBundle | None = get_jmarti_block_fit_bundle(line_block) if attached_fit_bundle is not None: runtime_data: JMartiRuntimeData = build_jmarti_runtime_data(attached_fit_bundle, time_step_s) set_jmarti_block_runtime_data(line_block, runtime_data) return runtime_data else: raise ValueError(f"JMARTI line '{line.name}' requires one fit bundle or one runtime data object attached to its EMT block") else: raise ValueError("JMARTI runtime expects a Line-compatible object") def _extract_jmarti_history_vars(line_block: Block, prefix: str, active_ph: List[str], line_name: str) -> List[Var]: """ Extract one ordered list of history-current event variables. :param line_block: JMARTI symbolic line block. :param prefix: History-variable prefix. :param active_ph: Active phase labels. :param line_name: Line name used in diagnostics. :return: Ordered history-variable list. 
""" vars_by_name: dict[str, Var] = dict((var.name, var) for var in line_block.event_dict.keys()) out: List[Var] = list() phase_index: int = 0 key_name: str phase_label: str while phase_index < len(active_ph): phase_label = active_ph[phase_index] key_name = f"{prefix}{phase_label}" if key_name in vars_by_name: out.append(vars_by_name[key_name]) else: raise ValueError(f"Missing JMARTI history var '{key_name}' for line '{line_name}'") phase_index += 1 return out def _initialize_jmarti_runtime_storage(runtime: JMartiHistoryRuntime) -> None: """ Allocate the numerical storage used by one JMARTI runtime. :param runtime: Target JMARTI runtime. :return: None. """ mode_index: int = 0 max_yc_order: int = 0 max_hres_order: int = 0 mode_data: JMartiModeRuntimeData modal_direct_terms: np.ndarray = np.zeros(runtime.m, dtype=np.complex128) while mode_index < runtime.m: mode_data = runtime.runtime_data.get_mode_data()[mode_index] runtime.yc_order_counts[mode_index] = mode_data.get_yc_alpha().size runtime.hres_order_counts[mode_index] = mode_data.get_hres_alpha().size modal_direct_terms[mode_index] = mode_data.get_yc_constant_term() max_yc_order = max(max_yc_order, int(runtime.yc_order_counts[mode_index])) max_hres_order = max(max_hres_order, int(runtime.hres_order_counts[mode_index])) mode_index += 1 runtime.direct_yc_modal = modal_direct_terms runtime.direct_zc_modal = np.zeros(runtime.m, dtype=np.complex128) non_zero_mask: np.ndarray = np.abs(runtime.direct_yc_modal) > 1.0e-14 runtime.direct_zc_modal[non_zero_mask] = 1.0 / runtime.direct_yc_modal[non_zero_mask] runtime.direct_yc_phase = np.real( runtime.runtime_data.get_modal_transform() @ np.diag(runtime.direct_yc_modal) @ runtime.runtime_data.get_modal_transform_inv() ) runtime.max_delay_steps = max(int(mode_data.get_delay_step_count()) for mode_data in runtime.runtime_data.get_mode_data()) runtime.buffer_size = max(2, runtime.max_delay_steps + 2) runtime.buf_a_f = np.zeros((runtime.buffer_size, runtime.m), 
dtype=np.complex128) runtime.buf_a_t = np.zeros((runtime.buffer_size, runtime.m), dtype=np.complex128) runtime.yc_state_from = np.zeros((runtime.m, max(1, max_yc_order)), dtype=np.complex128) runtime.yc_state_to = np.zeros((runtime.m, max(1, max_yc_order)), dtype=np.complex128) runtime.hres_state_from = np.zeros((runtime.m, max(1, max_hres_order)), dtype=np.complex128) runtime.hres_state_to = np.zeros((runtime.m, max(1, max_hres_order)), dtype=np.complex128) def _get_delayed_incident_wave(runtime: JMartiHistoryRuntime, mode_index: int, step_counter: int, remote_incident_buffer: np.ndarray, next_step: bool) -> complex: """ Return one delayed outgoing wave sample for the requested mode. :param runtime: Owning JMARTI runtime. :param mode_index: Modal channel index. :param step_counter: Current accepted step counter. :param remote_incident_buffer: Ring buffer of remote outgoing waves. :param next_step: Whether to sample the delay for the next step. :return: Delayed outgoing wave. """ mode_data: JMartiModeRuntimeData = runtime.runtime_data.get_mode_data()[mode_index] delay_step_count: int = mode_data.get_delay_step_count() residual_delay_s: float = mode_data.get_residual_delay_s() interpolation_fraction_previous: float = 0.0 interpolation_fraction_current: float = 1.0 delayed_step_index: int previous_step_index: int if runtime.h > 0.0: interpolation_fraction_previous = min(max(residual_delay_s / runtime.h, 0.0), 1.0) else: interpolation_fraction_previous = 0.0 interpolation_fraction_current = 1.0 - interpolation_fraction_previous if delay_step_count == 0: delayed_step_index = step_counter + 1 if next_step else step_counter else: if next_step: delayed_step_index = step_counter + 1 - delay_step_count else: delayed_step_index = step_counter - delay_step_count if interpolation_fraction_previous > 0.0: previous_step_index = delayed_step_index - 1 return complex( interpolation_fraction_previous * remote_incident_buffer[previous_step_index % runtime.buffer_size, mode_index] + 
interpolation_fraction_current * remote_incident_buffer[delayed_step_index % runtime.buffer_size, mode_index] ) else: return complex(remote_incident_buffer[delayed_step_index % runtime.buffer_size, mode_index]) def _sum_state_row(state_matrix: np.ndarray, mode_index: int, order_count: int) -> complex: """ Return the sum of one active row segment in a padded state matrix. :param state_matrix: Padded state matrix. :param mode_index: Modal row index. :param order_count: Active order in that row. :return: Sum of the active row entries. """ if order_count > 0: return complex(np.sum(state_matrix[mode_index, :order_count])) else: return 0.0 + 0.0j def _update_jmarti_state_row(state_matrix: np.ndarray, mode_index: int, order_count: int, alpha: np.ndarray, beta: np.ndarray, input_value: complex) -> None: """ Update one padded row of first-order discrete states. :param state_matrix: Padded state matrix. :param mode_index: Modal row index. :param order_count: Active order in that row. :param alpha: Active transition multipliers. :param beta: Active input gains. :param input_value: Current scalar input. :return: None. """ if order_count > 0: state_matrix[mode_index, :order_count] = alpha * state_matrix[mode_index, :order_count] + beta * input_value else: pass def _build_jmarti_steady_state_state_row(alpha: np.ndarray, beta: np.ndarray, input_value: complex) -> np.ndarray: """ Return the discrete steady-state of one first-order JMARTI state row. :param alpha: Exact state-transition multipliers. :param beta: Exact input gains. :param input_value: Constant scalar input. :return: Steady-state row values. 
""" denominator: np.ndarray = 1.0 - np.asarray(alpha, dtype=np.complex128) steady_state: np.ndarray = np.zeros_like(np.asarray(beta, dtype=np.complex128)) non_zero_mask: np.ndarray = np.abs(denominator) > 1.0e-14 steady_state[non_zero_mask] = np.asarray(beta, dtype=np.complex128)[non_zero_mask] * input_value / denominator[non_zero_mask] return steady_state def _build_jmarti_periodic_state_row(alpha: np.ndarray, beta: np.ndarray, input_value: complex, lambda_step: complex) -> np.ndarray: """ Return the discrete periodic state of one first-order JMARTI state row. :param alpha: Exact state-transition multipliers. :param beta: Exact input gains. :param input_value: Fundamental-frequency analytic input sample at ``n=0``. :param lambda_step: One-step complex phase advance ``exp(j*w*h)``. :return: Periodic steady-state row values. """ denominator: np.ndarray = complex(lambda_step) - np.asarray(alpha, dtype=np.complex128) periodic_state: np.ndarray = np.zeros_like(np.asarray(beta, dtype=np.complex128)) non_zero_mask: np.ndarray = np.abs(denominator) > 1.0e-14 periodic_state[non_zero_mask] = np.asarray(beta, dtype=np.complex128)[non_zero_mask] * input_value / denominator[non_zero_mask] return periodic_state def _update_jmarti_mode_history(runtime: JMartiHistoryRuntime, mode_index: int, step_counter: int, local_voltage_modal: complex, remote_incident_buffer: np.ndarray, yc_state: np.ndarray, hres_state: np.ndarray) -> complex: """ Advance one modal channel and return the next-step history current. :param runtime: Owning JMARTI runtime. :param mode_index: Modal channel index. :param step_counter: Current accepted step counter. :param local_voltage_modal: Local terminal voltage in modal coordinates. :param remote_incident_buffer: Ring buffer of remote outgoing waves. :param yc_state: Padded local Yc state matrix. :param hres_state: Padded propagation-filter state matrix. :return: Next-step history current in modal coordinates. 
""" mode_data: JMartiModeRuntimeData = runtime.runtime_data.get_mode_data()[mode_index] yc_order_count: int = int(runtime.yc_order_counts[mode_index]) hres_order_count: int = int(runtime.hres_order_counts[mode_index]) delayed_wave_current: complex = _get_delayed_incident_wave(runtime, mode_index, step_counter, remote_incident_buffer, next_step=False) current_hres_output: complex = mode_data.get_hres_constant_term() * delayed_wave_current + _sum_state_row(hres_state, mode_index, hres_order_count) yc_input_current: complex = local_voltage_modal - 2.0 * current_hres_output _update_jmarti_state_row( state_matrix=hres_state, mode_index=mode_index, order_count=hres_order_count, alpha=mode_data.get_hres_alpha(), beta=mode_data.get_hres_beta(), input_value=delayed_wave_current, ) _update_jmarti_state_row( state_matrix=yc_state, mode_index=mode_index, order_count=yc_order_count, alpha=mode_data.get_yc_alpha(), beta=mode_data.get_yc_beta(), input_value=yc_input_current, ) delayed_wave_next: complex = _get_delayed_incident_wave(runtime, mode_index, step_counter, remote_incident_buffer, next_step=True) next_hres_output: complex = mode_data.get_hres_constant_term() * delayed_wave_next + _sum_state_row(hres_state, mode_index, hres_order_count) next_yc_dynamic_output: complex = _sum_state_row(yc_state, mode_index, yc_order_count) return next_yc_dynamic_output - 2.0 * runtime.direct_yc_modal[mode_index] * next_hres_output