"""Core animation routines for :mod:`flipbook`."""
from __future__ import annotations
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
import matplotlib.animation as mpl_animation
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from numpy.typing import ArrayLike, NDArray
from .types import Array1D, Array2D, Chain, StepSelection, WalkerSelection, as_chain
from .utils.selection import (
compute_percentile_bands,
resolve_step_indices,
resolve_walker_indices,
select_topk_by_log_prob,
)
try:
from tqdm.auto import tqdm
except Exception: # pragma: no cover - tqdm is an optional dependency
tqdm = None
@dataclass
class _FrameResult:
    """Container holding all information required to render a frame.

    One instance is produced per selected chain step by
    ``_frame_data_generator`` and consumed by the rendering code.
    """

    # Index of the chain step this frame was computed from.
    step_index: int
    # Walker indices used for this frame, after any top-K selection and
    # ``max_curves_per_frame`` truncation.
    walker_indices: NDArray[np.int_]
    # Model curves, one row per entry of ``walker_indices``.
    curves: Array2D
    # Aggregate curve across the rows of ``curves``; None when no aggregate
    # was requested or there are no curves.
    aggregate: Array1D | None
    # Mapping from a band label to its ``(lower, upper)`` boundary curves.
    percentile_bands: dict[str, tuple[Array1D, Array1D]]
    # One colour tuple per curve, or None when no colouring strategy is active.
    colors: list[tuple[float, float, float, float]] | None
    # Log-probabilities of the selected walkers at this step, or None when the
    # chain carries no log-probability data.
    log_prob: Array1D | None
def _apply_param_transform(
    parameters: Array2D,
    transform: Callable[[Array1D], Array1D] | Callable[[Array2D], Array2D] | None,
) -> Array2D:
    """Return ``parameters`` with the optional ``transform`` applied.

    The transform is first tried on the whole two-dimensional block. If that
    call raises, or its result does not have the same shape as ``parameters``,
    the transform is applied to each row individually instead.

    Parameters
    ----------
    parameters : ndarray
        Two-dimensional array with one parameter vector per row.
    transform : callable or None
        Transformation to apply. ``None`` returns ``parameters`` unchanged.

    Returns
    -------
    ndarray
        The transformed parameters as floats, or the input object itself when
        ``transform`` is ``None``.
    """
    if transform is None:
        return parameters

    # Batched attempt. Any failure here is treated as "this callable only
    # understands single vectors" rather than as an error.
    try:
        whole_block = np.asarray(transform(parameters))
    except Exception:
        whole_block = None

    shape_preserved = whole_block is not None and whole_block.shape == parameters.shape
    if shape_preserved:
        return whole_block.astype(float, copy=False)

    # Row-wise fallback; errors raised by the transform now propagate.
    per_row = np.empty_like(parameters, dtype=float)
    for row, theta in enumerate(parameters):
        per_row[row] = np.asarray(transform(theta), dtype=float)
    return per_row
def _hash_theta(theta: Array1D) -> int:
    """Return a hash of the raw bytes of a parameter vector.

    Vectors with identical byte content hash to the same value within one
    interpreter process, which is all the in-memory curve cache needs.
    """
    raw_bytes = theta.tobytes()
    return hash(raw_bytes)
def _evaluate_model_curves(
    model_fn: Callable[[Array1D, Array1D], Array1D],
    t: Array1D,
    parameters: Array2D,
    *,
    vectorized: bool,
    chunk_size: int,
    n_jobs: int,
    cache: dict[tuple[int, int], Array1D],
) -> Array2D:
    """Compute one model curve per row of ``parameters``.

    Parameters
    ----------
    model_fn : callable
        Model called as ``model_fn(theta, t)``.
    t : ndarray
        Sample grid handed to ``model_fn``.
    parameters : ndarray
        Two-dimensional array with one parameter vector per row.
    vectorized : bool
        When true, ``model_fn`` is first tried on the whole parameter block.
        If that call raises, evaluation falls back to one row at a time.
    chunk_size : int
        Number of rows submitted to the thread pool per batch.
    n_jobs : int
        Number of worker threads; ``1`` evaluates serially.
    cache : dict
        Mutable mapping used to memoise single-row evaluations.

    Returns
    -------
    ndarray
        Array of shape ``(parameters.shape[0], t.size)``.

    Raises
    ------
    ValueError
        If the model output has an unexpected shape or contains non-finite
        values, or if ``chunk_size`` / ``n_jobs`` is not positive.
    """
    if parameters.size == 0:
        return np.empty((0, t.size), dtype=float)
    n_curves = parameters.shape[0]

    if vectorized:
        # A raising batched call is not fatal: drop to row-wise evaluation.
        try:
            batch = np.asarray(model_fn(parameters, t), dtype=float)
        except Exception:
            batch = None
        if batch is not None:
            if batch.shape != (n_curves, t.size):
                raise ValueError("Vectorized model function returned unexpected shape")
            if not np.isfinite(batch).all():
                raise ValueError("Model function returned non-finite values")
            return batch

    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if n_jobs <= 0:
        raise ValueError("n_jobs must be a positive integer")

    out = np.empty((n_curves, t.size), dtype=float)
    grid_id = id(t)

    def curve_for(theta: Array1D) -> Array1D:
        # Memoised single-row evaluation, keyed on the parameter bytes' hash
        # and the identity of the sample grid.
        cache_key = (_hash_theta(theta), grid_id)
        hit = cache.get(cache_key)
        if hit is not None:
            return hit.copy()
        fresh = np.asarray(model_fn(theta, t), dtype=float)
        if fresh.shape != t.shape:
            raise ValueError(
                "Model function must return an array matching the shape of 't'"
            )
        if fresh.ndim != 1:
            raise ValueError("Model output must be one-dimensional")
        if not np.isfinite(fresh).all():
            raise ValueError("Model function returned non-finite values")
        cache[cache_key] = fresh.copy()
        return fresh

    if n_jobs == 1 or n_curves == 1:
        for row, theta in enumerate(parameters):
            out[row] = curve_for(theta)
        return out

    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        # Submit one chunk at a time and collect it in row order before
        # moving on to the next chunk.
        for first in range(0, n_curves, chunk_size):
            last = min(first + chunk_size, n_curves)
            pending = [
                pool.submit(curve_for, parameters[row]) for row in range(first, last)
            ]
            for row, job in zip(range(first, last), pending):
                out[row] = job.result()
    return out
def _compute_aggregate(curves: Array2D, mode: str | None) -> Array1D | None:
    """Reduce the curves of one step to a single aggregate curve.

    Parameters
    ----------
    curves : ndarray
        Two-dimensional array with one curve per row.
    mode : {'median', 'mean', None}
        Reduction applied along the first axis.

    Returns
    -------
    ndarray or None
        The aggregate curve, or ``None`` when ``mode`` is ``None`` or
        ``curves`` is empty.

    Raises
    ------
    ValueError
        If ``mode`` is not recognised and ``curves`` is non-empty.
    """
    nothing_to_aggregate = mode is None or curves.size == 0
    if nothing_to_aggregate:
        return None
    if mode not in ("median", "mean"):
        raise ValueError("per_step_aggregate must be one of {'median', 'mean', None}")
    reducer = np.median if mode == "median" else np.mean
    return reducer(curves, axis=0)
def _frame_data_generator(
    chain: Chain,
    model_fn: Callable[[Array1D, Array1D], Array1D],
    t: Array1D,
    step_indices: NDArray[np.int_],
    walker_indices: NDArray[np.int_],
    *,
    param_transform: Callable[[Array1D], Array1D] | Callable[[Array2D], Array2D] | None,
    vectorized: bool,
    chunk_size: int,
    n_jobs: int,
    topk_by_logp: int | None,
    max_curves_per_frame: int | None,
    per_step_aggregate: str | None,
    percentile_bands: Sequence[float] | None,
    color_by: str | None,
    progress: bool,
) -> Iterator[_FrameResult]:
    """Yield one :class:`_FrameResult` per entry of ``step_indices``.

    Because this is a generator, argument validation and colour-scale setup
    run when the first frame is requested, not when the function is called.

    Parameters
    ----------
    chain : Chain
        Chain providing ``chain``, ``log_prob`` and ``nwalkers``.
    model_fn : callable
        Model called as ``model_fn(theta, t)``.
    t : ndarray
        Sample grid passed to ``model_fn``.
    step_indices, walker_indices : ndarray of int
        Steps to iterate over and walkers considered at every step.
    param_transform : callable or None
        Optional transformation applied to the parameters before evaluation.
    vectorized, chunk_size, n_jobs
        Forwarded to ``_evaluate_model_curves``.
    topk_by_logp : int or None
        When given, the walkers of each step are narrowed with
        ``select_topk_by_log_prob``.
    max_curves_per_frame : int or None
        When given, only the first ``max_curves_per_frame`` selected walkers
        are kept.
    per_step_aggregate : {'median', 'mean', None}
        Aggregate curve computed for each frame.
    percentile_bands : sequence of float or None
        Forwarded to ``compute_percentile_bands``.
    color_by : {'walker', 'logp', None}
        Colouring strategy for individual curves.
    progress : bool
        Wrap the step loop in a ``tqdm`` progress bar when tqdm is installed.

    Yields
    ------
    _FrameResult
        Data needed to draw one frame.

    Raises
    ------
    ValueError
        If ``max_curves_per_frame`` is not positive, ``color_by`` is not
        recognised, or ``color_by='logp'`` is requested without finite
        log-probability data.
    """
    # Memoises single-row model evaluations across frames.
    cache: dict[tuple[int, int], Array1D] = {}
    if max_curves_per_frame is not None and max_curves_per_frame <= 0:
        raise ValueError("max_curves_per_frame must be positive when provided")
    percentile_bands = list(percentile_bands) if percentile_bands is not None else None
    if color_by not in {None, "walker", "logp"}:
        raise ValueError("color_by must be one of {'walker', 'logp', None}")

    color_map = None
    norm = None
    if color_by == "walker":
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the pyplot
        # accessor is available on old and new releases alike.
        color_map = plt.get_cmap("viridis")
        norm = Normalize(vmin=0, vmax=max(chain.nwalkers - 1, 1))
    elif color_by == "logp":
        if chain.log_prob is None:
            raise ValueError("log_prob is required when color_by='logp'")
        subset = chain.log_prob[np.ix_(step_indices, walker_indices)]
        if subset.size == 0:
            # Nothing will be coloured; use a placeholder range instead of
            # letting ``np.min`` fail on an empty selection.
            vmin, vmax = 0.0, 1.0
        else:
            vmin = float(np.min(subset))
            vmax = float(np.max(subset))
            if not np.isfinite(vmin) or not np.isfinite(vmax):
                raise ValueError("log_prob contains non-finite values")
            if vmin == vmax:
                # Avoid a zero-width colour scale.
                vmax = vmin + 1.0
        norm = Normalize(vmin=vmin, vmax=vmax)
        color_map = plt.get_cmap("plasma")

    step_iterator: Iterable[int]
    if progress and tqdm is not None:
        step_iterator = tqdm(step_indices, desc="Generating frames")
    else:
        step_iterator = step_indices

    for step_index in step_iterator:
        selected = walker_indices
        if topk_by_logp is not None:
            selected = select_topk_by_log_prob(
                chain.log_prob, step_index, selected, topk_by_logp
            )
        if max_curves_per_frame is not None and selected.size > max_curves_per_frame:
            selected = selected[:max_curves_per_frame]

        parameters = chain.chain[step_index, selected, :]
        parameters = _apply_param_transform(parameters, param_transform)
        curves = _evaluate_model_curves(
            model_fn,
            t,
            parameters,
            vectorized=vectorized,
            chunk_size=chunk_size,
            n_jobs=n_jobs,
            cache=cache,
        )
        aggregate = _compute_aggregate(curves, per_step_aggregate)
        percentile_info = compute_percentile_bands(curves, percentile_bands)

        log_prob_values = None
        if chain.log_prob is not None:
            log_prob_values = chain.log_prob[step_index, selected]

        colors = None
        if color_map is not None and norm is not None:
            # For 'logp' the presence of ``chain.log_prob`` was checked above,
            # so ``log_prob_values`` is populated here.
            values = selected if color_by == "walker" else log_prob_values
            colors = [color_map(norm(float(value))) for value in values]

        yield _FrameResult(
            step_index=step_index,
            walker_indices=selected,
            curves=curves,
            aggregate=aggregate,
            percentile_bands=percentile_info,
            colors=colors,
            log_prob=log_prob_values,
        )
[docs]
def animate_walkers(
    model_fn: Callable[[Array1D, Array1D], Array1D],
    t: ArrayLike,
    chain: ArrayLike | Chain,
    *,
    log_prob: ArrayLike | None = None,
    out: str | Path | None = None,
    data_t: ArrayLike | None = None,
    data_y: ArrayLike | None = None,
    data_err: ArrayLike | None = None,
    param_transform: Callable[[Array1D], Array1D]
    | Callable[[Array2D], Array2D]
    | None = None,
    walkers: WalkerSelection = "all",
    step_slice: StepSelection | None = None,
    thin: int = 1,
    topk_by_logp: int | None = None,
    vectorized: bool = False,
    n_jobs: int = 1,
    chunk_size: int = 32,
    percentile_bands: Sequence[float] | None = None,
    per_step_aggregate: str | None = None,
    color_by: str | None = None,
    alpha: float = 0.15,
    max_curves_per_frame: int | None = None,
    fps: int = 15,
    dpi: int = 120,
    writer: str | mpl_animation.AbstractMovieWriter | None = None,
    title: str | Callable[[int], str] | None = None,
    y_label: str = "Model",
    ylim: tuple[float, float] | None = None,
    progress: bool = False,
) -> mpl_animation.FuncAnimation:
    """Animate walkers from an MCMC chain for a given model function.

    Parameters
    ----------
    model_fn : callable
        Callable implementing ``f(theta, t) -> y``.
    t : array_like
        One-dimensional array of time samples.
    chain : array_like or Chain
        Either a raw chain array with shape ``(nsteps, nwalkers, ndim)`` or a
        :class:`~flipbook.types.Chain` instance.
    log_prob : array_like, optional
        Log-probability values associated with ``chain``. Only required when
        the chain object does not already include them.
    out : str or Path, optional
        Output path. When provided, the animation is saved to disk using the
        requested writer. The animation object is still returned for further
        manipulation.
    data_t, data_y, data_err : array_like, optional
        Observational data to overlay. The overlay is drawn only when both
        ``data_t`` and ``data_y`` are given. ``data_err`` is interpreted as
        symmetric uncertainties when supplied.
    param_transform : callable, optional
        Transformation applied to each walker prior to evaluating
        ``model_fn``.
    walkers : {'all', int, sequence of int}, optional
        Walker selection specification.
    step_slice : slice, tuple, sequence of int, optional
        Steps to include in the animation.
    thin : int, optional
        Frame thinning factor applied after ``step_slice``.
    topk_by_logp : int, optional
        If provided, restricts each frame to the top-K walkers by log
        probability.
    vectorized : bool, optional
        Indicates that ``model_fn`` supports vectorized evaluation with an
        ``(nwalkers, ndim)`` parameter array.
    n_jobs : int, optional
        Number of threads to use for non-vectorized evaluation.
    chunk_size : int, optional
        Number of walkers evaluated together when ``n_jobs > 1``.
    percentile_bands : sequence of float, optional
        Percentile bands to shade in each frame.
    per_step_aggregate : {'median', 'mean', None}, optional
        Aggregate curve to highlight in each frame.
    color_by : {'walker', 'logp', None}, optional
        Strategy used to color individual walker curves.
    alpha : float, optional
        Base transparency applied to walker curves.
    max_curves_per_frame : int, optional
        Upper bound on the number of curves rendered per frame.
    fps : int, optional
        Target frames per second for the animation.
    dpi : int, optional
        Resolution when writing to disk.
    writer : str or matplotlib writer, optional
        Animation writer. When ``None`` a reasonable default is inferred from
        the output file extension.
    title : str or callable, optional
        Static title string or callable ``title(step_index) -> str`` executed
        per frame.
    y_label : str, optional
        Y-axis label for the plot.
    ylim : tuple, optional
        Y-axis limits.
    progress : bool, optional
        Display a progress bar while generating frames. Requires :mod:`tqdm`.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The generated animation object.

    Raises
    ------
    ValueError
        If ``fps`` is not positive, ``alpha`` lies outside ``[0, 1]``,
        ``max_curves_per_frame`` is not positive, ``t`` is not
        one-dimensional, no walkers are selected, or the overlay data arrays
        disagree in shape.
    """
    if fps <= 0:
        raise ValueError("fps must be a positive integer")
    if alpha < 0 or alpha > 1:
        raise ValueError("alpha must be within [0, 1]")
    # Checked here as well as in the frame generator so that a bad value fails
    # before any figure is created rather than in the middle of rendering.
    if max_curves_per_frame is not None and max_curves_per_frame <= 0:
        raise ValueError("max_curves_per_frame must be positive when provided")

    chain_obj = as_chain(chain, log_prob=log_prob)
    t_array = np.asarray(t, dtype=float)
    if t_array.ndim != 1:
        raise ValueError("t must be a one-dimensional array")

    nsteps, nwalkers = chain_obj.nsteps, chain_obj.nwalkers
    step_indices = resolve_step_indices(nsteps, step_slice=step_slice, thin=thin)
    walker_indices = resolve_walker_indices(nwalkers, walkers)
    if walker_indices.size == 0:
        raise ValueError("No walkers selected for animation")

    # A frame never holds more curves than there are selected walkers, so the
    # pool of line artists does not need to exceed that number.
    if max_curves_per_frame is None:
        max_artists = int(walker_indices.size)
    else:
        max_artists = int(min(walker_indices.size, max_curves_per_frame))

    frame_iter = _frame_data_generator(
        chain_obj,
        model_fn,
        t_array,
        step_indices,
        walker_indices,
        param_transform=param_transform,
        vectorized=vectorized,
        chunk_size=chunk_size,
        n_jobs=n_jobs,
        topk_by_logp=topk_by_logp,
        max_curves_per_frame=max_curves_per_frame,
        per_step_aggregate=per_step_aggregate,
        percentile_bands=percentile_bands,
        color_by=color_by,
        progress=progress,
    )

    fig, ax = plt.subplots()
    ax.set_xlabel("t")
    ax.set_ylabel(y_label)
    ax.set_xlim(float(t_array.min()), float(t_array.max()))
    if ylim is not None:
        ax.set_ylim(*ylim)

    if data_t is not None and data_y is not None:
        data_t_array = np.asarray(data_t, dtype=float)
        data_y_array = np.asarray(data_y, dtype=float)
        if data_t_array.shape != data_y_array.shape:
            raise ValueError("data_t and data_y must share the same shape")
        if data_err is not None:
            data_err_array = np.asarray(data_err, dtype=float)
            if data_err_array.shape != data_y_array.shape:
                raise ValueError("data_err must match the shape of data_y")
            ax.errorbar(
                data_t_array,
                data_y_array,
                yerr=data_err_array,
                fmt="o",
                color="black",
                alpha=0.8,
                label="data",
            )
        else:
            ax.plot(data_t_array, data_y_array, "o", color="black", label="data")

    # Line artists are created once and re-used for every frame.
    walker_lines: list[Line2D] = [
        ax.plot([], [], lw=1.2, alpha=alpha)[0] for _ in range(max_artists)
    ]
    aggregate_line = None
    if per_step_aggregate is not None:
        (aggregate_line,) = ax.plot(
            [], [], color="black", lw=2.0, label=per_step_aggregate
        )
    # Percentile bands are re-created for every frame; this tracks the ones
    # currently on the axes so they can be removed first.
    percentile_artists: dict[str, PolyCollection] = {}

    def init() -> list:
        artists: list = []
        for line in walker_lines:
            line.set_data([], [])
            artists.append(line)
        if aggregate_line is not None:
            aggregate_line.set_data([], [])
            artists.append(aggregate_line)
        return artists

    def update(frame: _FrameResult) -> list:
        artists: list = []
        curves = frame.curves
        for idx, line in enumerate(walker_lines):
            if idx < curves.shape[0]:
                line.set_data(t_array, curves[idx])
                if frame.colors is not None:
                    line.set_color(frame.colors[idx])
                line.set_alpha(alpha)
            else:
                # Blank out artists that have no curve in this frame.
                line.set_data([], [])
            artists.append(line)
        if aggregate_line is not None:
            if frame.aggregate is None:
                aggregate_line.set_data([], [])
            else:
                aggregate_line.set_data(t_array, frame.aggregate)
            artists.append(aggregate_line)
        for artist in percentile_artists.values():
            artist.remove()
        percentile_artists.clear()
        for label, (lo, hi) in frame.percentile_bands.items():
            band = ax.fill_between(t_array, lo, hi, alpha=0.2, label=label)
            percentile_artists[label] = band
            artists.append(band)
        if callable(title):
            ax.set_title(title(frame.step_index))
        elif isinstance(title, str):
            ax.set_title(title)
        return artists

    animation = mpl_animation.FuncAnimation(
        fig,
        update,
        frames=frame_iter,
        init_func=init,
        blit=False,
        interval=1000.0 / float(fps),
        # ``frames`` is a generator, so Matplotlib cannot infer its length.
        # Without an explicit ``save_count`` some releases fall back to 100
        # frames and silently truncate longer animations when saving.
        save_count=len(step_indices),
    )
    if out is not None:
        output_path = Path(out)
        resolved_writer = _resolve_writer(writer, output_path.suffix)
        animation.save(str(output_path), writer=resolved_writer, dpi=dpi, fps=fps)
    return animation
def _resolve_writer(
    writer: str | mpl_animation.AbstractMovieWriter | None,
    suffix: str,
) -> str | mpl_animation.AbstractMovieWriter:
    """Resolve the animation writer from user input and file suffix.

    Parameters
    ----------
    writer : str, matplotlib writer or None
        Explicit writer. A writer name is checked for availability; a writer
        instance is returned unchanged.
    suffix : str
        File suffix of the output path, including the leading dot. Matching
        is case-insensitive.

    Returns
    -------
    str or matplotlib writer
        The explicit ``writer`` when given, otherwise the name of the first
        available writer suited to ``suffix``.

    Raises
    ------
    ValueError
        If a requested writer name is unavailable, or no suitable writer is
        installed for ``suffix``.
    """
    if writer is not None:
        if isinstance(writer, str) and not mpl_animation.writers.is_available(writer):
            raise ValueError(f"Requested writer '{writer}' is not available")
        return writer

    suffix = suffix.lower()
    # Candidates in order of preference for each kind of output.
    if suffix in {".mp4", ".m4v", ".mov"}:
        candidates = ("ffmpeg",)
    elif suffix == ".gif":
        # Pillow can encode GIFs as well, so a missing ImageMagick install
        # should not make GIF output impossible.
        candidates = ("imagemagick", "pillow")
    else:
        candidates = ("pillow",)

    for candidate in candidates:
        if mpl_animation.writers.is_available(candidate):
            return candidate
    raise ValueError(
        "No suitable animation writer available. Install ffmpeg, imagemagick, or pillow."
    )
[docs]
def snapshot_step(
    model_fn: Callable[[Array1D, Array1D], Array1D],
    t: ArrayLike,
    chain: ArrayLike | Chain,
    *,
    step: int,
    walkers: WalkerSelection = "all",
    log_prob: ArrayLike | None = None,
    param_transform: Callable[[Array1D], Array1D]
    | Callable[[Array2D], Array2D]
    | None = None,
    vectorized: bool = False,
    n_jobs: int = 1,
    chunk_size: int = 32,
    percentile_bands: Sequence[float] | None = None,
    per_step_aggregate: str | None = None,
    color_by: str | None = None,
    alpha: float = 0.15,
    max_curves_per_frame: int | None = None,
    title: str | None = None,
    y_label: str = "Model",
    ylim: tuple[float, float] | None = None,
) -> Figure:
    """Draw the model curves of a single chain step as a static figure.

    Parameters
    ----------
    model_fn : callable
        Callable implementing ``f(theta, t) -> y``.
    t : array_like
        One-dimensional array of time samples.
    chain : array_like or Chain
        Chain data or :class:`~flipbook.types.Chain` instance.
    step : int
        Index of the step to visualise.
    walkers : {'all', int, sequence of int}, optional
        Walker selection specification.
    log_prob : array_like, optional
        Log-probability values for the chain.
    param_transform : callable, optional
        Transformation applied before the model is evaluated.
    vectorized : bool, optional
        Indicates that ``model_fn`` supports vectorized evaluation.
    n_jobs : int, optional
        Number of worker threads for non-vectorized evaluation.
    chunk_size : int, optional
        Batch size for threaded evaluation.
    percentile_bands : sequence of float, optional
        Percentile bands to shade.
    per_step_aggregate : {'median', 'mean', None}, optional
        Aggregate curve to highlight.
    color_by : {'walker', 'logp', None}, optional
        Strategy used to colour individual walker curves.
    alpha : float, optional
        Transparency applied to walker curves; must lie in ``[0, 1]``.
    max_curves_per_frame : int, optional
        Upper bound on the number of curves drawn.
    title : str, optional
        Figure title.
    y_label : str, optional
        Y-axis label.
    ylim : tuple, optional
        Y-axis limits.

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the snapshot.

    Raises
    ------
    ValueError
        If ``alpha`` lies outside ``[0, 1]`` or ``t`` is not one-dimensional.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("alpha must be within [0, 1]")

    chain_obj = as_chain(chain, log_prob=log_prob)
    t_array = np.asarray(t, dtype=float)
    if t_array.ndim != 1:
        raise ValueError("t must be a one-dimensional array")

    # Re-use the animation pipeline with a single step and take its only frame.
    single_step = np.asarray([int(step)], dtype=int)
    chosen_walkers = resolve_walker_indices(chain_obj.nwalkers, walkers)
    frame = next(
        _frame_data_generator(
            chain_obj,
            model_fn,
            t_array,
            single_step,
            chosen_walkers,
            param_transform=param_transform,
            vectorized=vectorized,
            chunk_size=chunk_size,
            n_jobs=n_jobs,
            topk_by_logp=None,
            max_curves_per_frame=max_curves_per_frame,
            per_step_aggregate=per_step_aggregate,
            percentile_bands=percentile_bands,
            color_by=color_by,
            progress=False,
        )
    )

    fig, ax = plt.subplots()
    ax.set_xlabel("t")
    ax.set_ylabel(y_label)
    ax.set_xlim(float(t_array.min()), float(t_array.max()))
    if ylim is not None:
        ax.set_ylim(*ylim)

    # Individual curves first, then the aggregate, then the shaded bands.
    palette = frame.colors
    for row, curve in enumerate(frame.curves):
        line_color = None if palette is None else palette[row]
        ax.plot(t_array, curve, color=line_color, alpha=alpha)
    if frame.aggregate is not None:
        ax.plot(t_array, frame.aggregate, color="black", lw=2.0)
    for band_label, (lower, upper) in frame.percentile_bands.items():
        ax.fill_between(t_array, lower, upper, alpha=0.2, label=band_label)

    if title is not None:
        ax.set_title(title)
    return fig
[docs]
def precompute_curves(
    model_fn: Callable[[Array1D, Array1D], Array1D],
    t: ArrayLike,
    chain: ArrayLike | Chain,
    *,
    steps: StepSelection | None = None,
    walkers: WalkerSelection = "all",
    log_prob: ArrayLike | None = None,
    param_transform: Callable[[Array1D], Array1D]
    | Callable[[Array2D], Array2D]
    | None = None,
    vectorized: bool = False,
    n_jobs: int = 1,
    chunk_size: int = 32,
    topk_by_logp: int | None = None,
    max_curves_per_frame: int | None = None,
) -> Iterator[dict[str, object]]:
    """Lazily compute model curves for a selection of chain steps.

    The chain, the sample grid and the step/walker selections are resolved
    when the function is called; the model itself is evaluated step by step
    as the returned generator is consumed.

    Parameters
    ----------
    model_fn : callable
        Callable implementing ``f(theta, t) -> y``.
    t : array_like
        One-dimensional array of time samples.
    chain : array_like or Chain
        Chain data or :class:`~flipbook.types.Chain` instance.
    steps : slice, tuple, sequence of int, optional
        Steps to pre-compute. When ``None`` all steps are considered.
    walkers : {'all', int, sequence of int}, optional
        Walker selection specification.
    log_prob : array_like, optional
        Log-probability values for the chain.
    param_transform : callable, optional
        Transformation applied before the model is evaluated.
    vectorized : bool, optional
        Indicates that ``model_fn`` supports vectorized evaluation.
    n_jobs : int, optional
        Number of worker threads for non-vectorized evaluation.
    chunk_size : int, optional
        Batch size for threaded evaluation.
    topk_by_logp : int, optional
        If provided, restricts walkers to the top-K by log probability per
        step.
    max_curves_per_frame : int, optional
        Maximum number of curves retained per step.

    Returns
    -------
    generator of dict
        Generator yielding dictionaries with keys ``'step_index'``,
        ``'walker_indices'``, ``'curves'``, and ``'log_prob'``.

    Raises
    ------
    ValueError
        If ``t`` is not one-dimensional.
    """
    chain_obj = as_chain(chain, log_prob=log_prob)
    t_array = np.asarray(t, dtype=float)
    if t_array.ndim != 1:
        raise ValueError("t must be a one-dimensional array")

    # No aggregate, bands or colours are needed here; only the raw curves.
    frames = _frame_data_generator(
        chain_obj,
        model_fn,
        t_array,
        resolve_step_indices(chain_obj.nsteps, step_slice=steps, thin=1),
        resolve_walker_indices(chain_obj.nwalkers, walkers),
        param_transform=param_transform,
        vectorized=vectorized,
        chunk_size=chunk_size,
        n_jobs=n_jobs,
        topk_by_logp=topk_by_logp,
        max_curves_per_frame=max_curves_per_frame,
        per_step_aggregate=None,
        percentile_bands=None,
        color_by=None,
        progress=False,
    )
    return (
        {
            "step_index": result.step_index,
            "walker_indices": result.walker_indices,
            "curves": result.curves,
            "log_prob": result.log_prob,
        }
        for result in frames
    )
[docs]
def animate_from_emcee(
    model_fn: Callable[[Array1D, Array1D], Array1D],
    t: ArrayLike,
    sampler: object,
    *,
    out: str | Path | None = None,
    **kwargs: Any,
) -> mpl_animation.FuncAnimation:
    """Animate walkers directly from an :mod:`emcee` sampler.

    Parameters
    ----------
    model_fn : callable
        Callable implementing ``f(theta, t) -> y``.
    t : array_like
        One-dimensional array of time samples.
    sampler : object
        ``emcee`` sampler providing ``get_chain`` and optionally
        ``get_log_prob``.
    out : str or Path, optional
        Output filename passed through to :func:`animate_walkers`.
    **kwargs
        Additional keyword arguments forwarded to :func:`animate_walkers`.
        An explicit ``log_prob`` given here takes precedence over the values
        reported by the sampler.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The generated animation object.

    Raises
    ------
    TypeError
        If ``sampler`` does not provide the expected ``get_chain`` method.
    """
    if not hasattr(sampler, "get_chain"):
        raise TypeError("sampler must provide a 'get_chain' method")
    chain_array = sampler.get_chain()

    forwarded = dict(kwargs)
    # ``log_prob`` is a regular keyword of ``animate_walkers``; passing it both
    # explicitly and via ``**kwargs`` would raise "multiple values for keyword
    # argument", so the sampler is only consulted when the caller gave none.
    if "log_prob" not in forwarded:
        log_prob_array = None
        if hasattr(sampler, "get_log_prob"):
            log_prob_array = sampler.get_log_prob()
        forwarded["log_prob"] = log_prob_array

    return animate_walkers(
        model_fn,
        t,
        chain_array,
        out=out,
        **forwarded,
    )