
02 — CRT demo

LV/RV trees, activation, and BiV scenarios

This demo builds LV and RV Purkinje trees on the provided endocardial surfaces, runs a fast activation solver (the fast iterative method, FIM), and visualizes activation on the chambers for several scenarios:

  • baseline: LV and RV start together

  • rv_only: only RV is paced (LV is unreached → masked)

  • biv: BiV with a user-controlled VV delay (VV_DELAY_MS)

  • biv_paper: BiV with the paper’s LV-early preset (−75 ms)

We use the paper's ground-truth seed indices and parameters to reproduce the reference trees.

Environment & dependencies

This notebook is designed to run in an isolated environment (e.g., Colab).

  • Installs: purkinje-uv and plotly

  • No GPU is required.

  • Output is written under output/examples/02_crt_demo/.

[ ]:
%pip install -q --upgrade pip purkinje-uv plotly
[ ]:
from pathlib import Path
import os
import json
import numpy as np

from purkinje_uv import FractalTreeParameters, FractalTree, PurkinjeTree
import pyvista as pv
import plotly.graph_objects as go

Reproducibility knobs & paths

  • SEED controls NumPy’s RNG (tree growth is deterministic given parameters; this is for ancillary steps).

  • LITE=1 selects quicker defaults when parameters are generated automatically; this notebook uses the fixed paper preset, so LITE has no effect here.

  • VV_DELAY_MS controls the biv scenario: the LV is paced this many ms after the RV (negative values pace the LV first).
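To make the sign convention concrete, here is a minimal sketch. The helper `biv_sources` is hypothetical (not part of purkinje-uv): it builds the `(node_index, time_s)` source list that the biv scenario passes to the solver, with the RV at 0 ms and the LV at VV_DELAY_MS. The optional shift, which moves the earlier stimulus to 0 s, is an assumption of this sketch; the notebook itself passes a negative LV time unchanged when VV_DELAY_MS < 0.

```python
def biv_sources(idx_lv_root: int, idx_rv_root: int, vv_delay_ms: float,
                shift_to_zero: bool = False) -> list[tuple[int, float]]:
    """Hypothetical helper: BiV pacing sources as (node_index, time_s).

    RV is paced at 0 ms and LV at vv_delay_ms, mirroring the biv scenario
    (negative vv_delay_ms means the LV is paced first).
    """
    t_rv = 0.0
    t_lv = vv_delay_ms / 1000.0  # ms -> s
    if shift_to_zero:
        # Optional (assumption): shift both times so the earliest stimulus is at 0 s
        t0 = min(t_rv, t_lv)
        t_rv, t_lv = t_rv - t0, t_lv - t0
    return [(idx_rv_root, t_rv), (idx_lv_root, t_lv)]


# Usage with the merged-tree layout (LV nodes first, so the RV root sits at n_lv)
n_lv_demo = 500  # hypothetical LV node count
print(biv_sources(0, n_lv_demo, -40))                      # LV 40 ms early, negative LV time
print(biv_sources(0, n_lv_demo, -40, shift_to_zero=True))  # same delay, earliest stimulus at 0 s
```

Shifting only changes the absolute time origin; the relative LV-RV offset, which is what the VV delay means, stays the same.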

[ ]:
# Repro + knobs
SEED = int(os.getenv("EXAMPLES_SEED", "1234"))
LITE = bool(int(os.getenv("EXAMPLES_LITE", "1")))          # fast by default
VV_DELAY_MS = int(os.getenv("CRT_VV_DELAY_MS", "0"))       # BiV LV delay vs RV (ms); try -40, 0, +40 later

# Locations
DATA_DIR = Path("data") / "crtdemo"
OUT_DIR = Path("output") / "examples" / "02_crt_demo"
OUT_DIR.mkdir(parents=True, exist_ok=True)

np.random.seed(SEED)
print(f"SEED={SEED}  LITE={LITE}  VV_DELAY_MS={VV_DELAY_MS}")
print("DATA_DIR:", DATA_DIR)
print("OUT_DIR:", OUT_DIR)

Mesh assets

We use the two open-surface endocardial meshes:

  • crtdemo_LVendo_heart_cut.obj

  • crtdemo_RVendo_heart_cut.obj

If missing locally, they are downloaded from the repo. We load them via PyVista, clean, and triangulate.

[ ]:
# Ensure demo OBJ files exist (download from repo if missing)
import urllib.request

LV_OBJ = DATA_DIR / "crtdemo_LVendo_heart_cut.obj"
RV_OBJ = DATA_DIR / "crtdemo_RVendo_heart_cut.obj"
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Try the main branch first, then the examples branch
BASE_URLS = [
    "https://raw.githubusercontent.com/ricardogr07/purkinje-uv/main/data/crtdemo/",
    "https://raw.githubusercontent.com/ricardogr07/purkinje-uv/feat/create-examples-tutorials/data/crtdemo/",
]

def _try_fetch(url: str, dst: Path) -> bool:
    try:
        print("Downloading:", url)
        with urllib.request.urlopen(url, timeout=30) as r:
            data = r.read()  # read fully first so a failed download leaves no empty file
        dst.write_bytes(data)
        return True
    except Exception as e:
        print("  failed:", e)
        return False

for obj in (LV_OBJ, RV_OBJ):
    if not obj.exists():
        for base in BASE_URLS:
            if _try_fetch(base + obj.name, obj):
                break

assert LV_OBJ.exists(), f"Missing LV OBJ: {LV_OBJ}"
assert RV_OBJ.exists(), f"Missing RV OBJ: {RV_OBJ}"
print("LV_OBJ:", LV_OBJ)
print("RV_OBJ:", RV_OBJ)

# Load with PyVista and verify they're open surfaces
def load_surface(path: Path) -> pv.PolyData:
    mesh = pv.read(str(path))
    if not isinstance(mesh, pv.PolyData):
        mesh = mesh.extract_geometry()
    mesh = mesh.clean().triangulate()
    return mesh

surf_lv = load_surface(LV_OBJ)
surf_rv = load_surface(RV_OBJ)

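def is_open_surface(pd: pv.PolyData) -> bool:
    # An open surface has boundary edges (edges used by exactly one face);
    # non-manifold edges are not boundaries, so they are excluded here.
    edges = pd.extract_feature_edges(boundary_edges=True, feature_edges=False,
                                     manifold_edges=False, non_manifold_edges=False)
    return edges.n_cells > 0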

print(f"LV: points={surf_lv.n_points} faces={surf_lv.n_cells} open={is_open_surface(surf_lv)}")
print(f"RV: points={surf_rv.n_points} faces={surf_rv.n_cells} open={is_open_surface(surf_rv)}")
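The open-surface check rests on a simple property of triangle meshes: a boundary edge is used by exactly one face, while edges shared by more than two faces are non-manifold rather than boundary. The pure-Python sketch below (hypothetical `boundary_edges` helper, independent of PyVista) counts edge usage directly on a face list; a closed surface has no boundary edges, while a cut endocardial surface should have some.

```python
from collections import Counter


def boundary_edges(faces: list[tuple[int, int, int]]) -> list[tuple[int, int]]:
    """Return the undirected edges that belong to exactly one triangle."""
    counts: Counter = Counter()
    for a, b, c in faces:
        for u, v in ((a, b), (b, c), (c, a)):
            counts[(min(u, v), max(u, v))] += 1  # undirected edge key
    return sorted(e for e, n in counts.items() if n == 1)


# A square split into two triangles is open: its 4 outer edges are boundaries
square = [(0, 1, 2), (0, 2, 3)]
# A tetrahedron is closed: every edge is shared by exactly two faces
tetra = [(0, 1, 2), (0, 3, 1), (1, 3, 2), (2, 3, 0)]
print(boundary_edges(square))
print(boundary_edges(tetra))
```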

Paper ground-truth seeds & parameters

We follow the published demo by using fixed vertex indices for the first two nodes per chamber and the same parameter values:

  • Seeds (0-based): LV (388, 412), RV (198, 186)

  • Key parameters (per demo): N_it=20, branch_angle=0.15, l_segment=1.0, length=8.0, w=0.1, plus small chamber-specific fascicle angles/lengths and init_length.
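The chamber-specific fascicle values are scaled before use, lengths by 0.5 and angles by 0.1, as in the preset cell. The sketch below only reproduces that arithmetic for the LV and converts the scaled angles to degrees as a readability check; it assumes the angles are in radians, which is consistent with branch_angle=0.15 but is not stated by the source.

```python
import math

# Raw LV values copied from the paper preset
LV_FAS_LEN_RAW = [4.711579058738858, 9.129484609771032]
LV_FAS_ANG_RAW = [0.14448952070696136, 0.23561944901923448]

LEN_SCALE, ANG_SCALE = 0.5, 0.1  # demo scaling factors

lv_fas_len = [LEN_SCALE * x for x in LV_FAS_LEN_RAW]
lv_fas_ang = [ANG_SCALE * a for a in LV_FAS_ANG_RAW]
# Assumption: angles are radians; convert to degrees for a human-readable check
lv_fas_ang_deg = [math.degrees(a) for a in lv_fas_ang]

print("LV fascicle lengths:", [round(x, 3) for x in lv_fas_len])
print("LV fascicle angles (deg):", [round(d, 2) for d in lv_fas_ang_deg])
```

Under that assumption the scaled fascicle angles are only about 1 degree, so the fascicles leave the initial branch almost straight.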

To avoid “initial branch goes out of the domain” failures, grow_with_retry automatically retries with init_length scaled by 0.75, 0.5, 0.35, and 0.25 in turn; any other error is raised immediately.
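The retry logic is a general pattern: walk through a sequence of shrink factors and swallow only the specific, recoverable error. Here is a library-independent sketch; `retry_with_scales` and `fake_grow` are hypothetical, with `fake_grow` standing in for FractalTree.grow_tree and simply failing whenever the initial length exceeds an arbitrary threshold.

```python
def retry_with_scales(grow, base_value: float,
                      scales=(1.0, 0.75, 0.5, 0.35, 0.25),
                      recoverable: str = "out of the domain"):
    """Call grow(base_value * s) for each scale until one succeeds.

    Only RuntimeErrors whose message contains `recoverable` trigger a retry;
    any other error is re-raised immediately.
    """
    last_err = None
    for s in scales:
        value = base_value * s
        try:
            return grow(value), value
        except RuntimeError as e:
            if recoverable not in str(e):
                raise
            last_err = e
    raise RuntimeError(f"All scales failed; last error: {last_err}")


# Hypothetical grower: fails when the initial branch is longer than 30 units
def fake_grow(init_length: float) -> str:
    if init_length > 30.0:
        raise RuntimeError("initial branch goes out of the domain")
    return f"tree(init_length={init_length:.2f})"


tree, used = retry_with_scales(fake_grow, 79.86354832236707)  # RV preset init_length
print(tree, used)
```

With this fake grower, scales 1.0, 0.75 and 0.5 fail and 0.35 succeeds.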

[ ]:
# --- Paper ground-truth preset (from PurkinjeECG demo) ---

from dataclasses import asdict

# 0-based vertex indices on the demo meshes (LV: 441 pts, RV: 817 pts in our copy)
LV_SEEDS = (388, 412)
RV_SEEDS = (198, 186)

# Ground-truth values from the demo; as in the demo, fascicle lengths are scaled by 0.5 and fascicle angles by 0.1 below
LV_INIT_LENGTH = float(35.931537038275316)
RV_INIT_LENGTH = float(79.86354832236707)

LV_FAS_LEN = [float(0.5*4.711579058738858), float(0.5*9.129484609771032)]
RV_FAS_LEN = [float(0.5*21.703867933650002), float(0.5*5.79561866201451)]

LV_FAS_ANG = [float(0.1*0.14448952070696136), float(0.1*0.23561944901923448)]
RV_FAS_ANG = [float(0.1*0.23561944901923448), float(0.1*0.23561944901923448)]

COMMON = dict(
    length=8.0,          # mean branch length
    w=0.1,               # repulsivity between branches
    l_segment=1.0,       # segment length (compare with mesh resolution)
    branch_angle=0.15,   # angle between parent and child branch (rad)
    N_it=20              # number of branch generations
)

def clamp_seeds(surf, seeds):
    n = surf.n_points
    return tuple(int(max(0, min(n-1, s))) for s in seeds)

LV_SEEDS = clamp_seeds(surf_lv, LV_SEEDS)
RV_SEEDS = clamp_seeds(surf_rv, RV_SEEDS)

def build_params_paper(chamber, meshfile, seeds):
    if chamber == "LV":
        init_len = LV_INIT_LENGTH
        fas_len  = LV_FAS_LEN
        fas_ang  = LV_FAS_ANG
    else:
        init_len = RV_INIT_LENGTH
        fas_len  = RV_FAS_LEN
        fas_ang  = RV_FAS_ANG
    return FractalTreeParameters(
        meshfile=str(meshfile),
        init_node_id=seeds[0],
        second_node_id=seeds[1],
        init_length=init_len,
        length=COMMON["length"],
        w=COMMON["w"],
        l_segment=COMMON["l_segment"],
        fascicles_length=fas_len,
        fascicles_angles=fas_ang,
        branch_angle=COMMON["branch_angle"],
        N_it=COMMON["N_it"],
    )

def with_init_length(p: FractalTreeParameters, new_init):
    # dataclass is frozen; rebuild with the updated init_length
    d = asdict(p)
    d["init_length"] = float(new_init)
    return FractalTreeParameters(**d)

def grow_with_retry(params, scales=(1.0, 0.75, 0.5, 0.35, 0.25)):
    last_err = None
    for s in scales:
        p_try = with_init_length(params, params.init_length * s)
        try:
            ft = FractalTree(params=p_try)
            ft.grow_tree()
            return ft, p_try
        except RuntimeError as e:
            last_err = e
            if "out of the domain" in str(e):
                continue
            raise
    raise RuntimeError(f"Failed to grow even after shrinking init_length. Last error: {last_err}")

# Build params and grow using the paper preset
params_lv = build_params_paper("LV", LV_OBJ, LV_SEEDS)
params_rv = build_params_paper("RV", RV_OBJ, RV_SEEDS)
print("Paper preset params (LV):", params_lv)
print("Paper preset params (RV):", params_rv)

Grow LV/RV trees and export

After growth we report node/edge/PMJ counts and export:

  • lv_tree_AT.vtu, rv_tree_AT.vtu — trees with activation arrays

  • lv_pmj.vtu, rv_pmj.vtu — PMJ locations

A quick sanity activation is run per chamber (root @ 0 s) to verify ranges.
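PMJs sit at the terminal ends of the tree. Assuming end_nodes are the leaves of the tree graph (degree-1 nodes other than the root at index 0), they can be recovered from the edge list alone. The hypothetical `leaf_nodes` helper below does that and can serve as a cross-check.

```python
from collections import Counter


def leaf_nodes(edges: list[tuple[int, int]], root: int = 0) -> list[int]:
    """Return nodes with exactly one incident edge, excluding the root."""
    degree: Counter = Counter()
    for u, v in edges:
        degree[u] += 1
        degree[v] += 1
    return sorted(n for n, d in degree.items() if d == 1 and n != root)


# Small Y-shaped tree: 0 -> 1 -> 2, and 1 branches to 3 -> 4
y_edges = [(0, 1), (1, 2), (1, 3), (3, 4)]
print(leaf_nodes(y_edges))  # candidate PMJ indices
```

For example, compare len(leaf_nodes(edges_lv.tolist())) with len(pmj_lv). A mismatch would suggest that the library defines end_nodes differently from "leaves".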

[ ]:
ft_lv, params_lv = grow_with_retry(params_lv)
ft_rv, params_rv = grow_with_retry(params_rv)

nodes_lv, edges_lv, pmj_lv = np.asarray(ft_lv.nodes_xyz), np.asarray(ft_lv.connectivity), np.asarray(ft_lv.end_nodes)
nodes_rv, edges_rv, pmj_rv = np.asarray(ft_rv.nodes_xyz), np.asarray(ft_rv.connectivity), np.asarray(ft_rv.end_nodes)

print(f"LV: nodes={len(nodes_lv)} edges={len(edges_lv)} pmj={len(pmj_lv)}")
print(f"RV: nodes={len(nodes_rv)} edges={len(edges_rv)} pmj={len(pmj_rv)}")
[ ]:
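def solve_activation(P: PurkinjeTree, sources, unit="s"):
    """Run the FIM activation solver from timed sources.

    sources: iterable of (node_index, start_time); unit: "s" or "ms" for start_time.
    Returns activation times (seconds) for all tree nodes.
    """
    x0 = np.array([s[0] for s in sources], dtype=int)
    t  = np.array([float(s[1]) for s in sources], dtype=float)
    if unit == "ms":
        t = t / 1000.0  # the solver expects seconds
    return P.activate_fim(x0=x0, x0_vals=t, return_only_pmj=False)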
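On a tree (a 1-D graph) with uniform conduction velocity, activation from several timed sources reduces to a multi-source shortest-path problem. The sketch below uses Dijkstra's algorithm. It is a swapped-in technique, not the library's activate_fim, and the velocity `cv` (length units per second) is a hypothetical parameter. It shows roughly what solve_activation computes under those assumptions and why nodes disconnected from every source stay at infinity.

```python
import heapq
import math


def graph_activation(n_nodes, edges, coords, sources, cv=1.0):
    """Multi-source Dijkstra: AT[i] = min over sources of t0 + path_length / cv.

    sources: iterable of (node_index, start_time_s). Nodes not connected to any
    source keep math.inf ("unreached").
    """
    adj = [[] for _ in range(n_nodes)]
    for u, v in edges:
        w = math.dist(coords[u], coords[v]) / cv  # travel time along the edge
        adj[u].append((v, w))
        adj[v].append((u, w))

    at = [math.inf] * n_nodes
    heap = []
    for node, t0 in sources:
        if t0 < at[node]:
            at[node] = t0
            heapq.heappush(heap, (t0, node))

    while heap:
        t, u = heapq.heappop(heap)
        if t > at[u]:
            continue  # stale heap entry
        for v, w in adj[u]:
            if t + w < at[v]:
                at[v] = t + w
                heapq.heappush(heap, (at[v], v))
    return at


# Two disconnected 3-node chains ("LV" = nodes 0-2, "RV" = nodes 3-5), unit spacing
coords = [(0, 0, 0), (1, 0, 0), (2, 0, 0), (0, 5, 0), (1, 5, 0), (2, 5, 0)]
chain_edges = [(0, 1), (1, 2), (3, 4), (4, 5)]
print(graph_activation(6, chain_edges, coords, [(3, 0.0)]))            # RV only: LV stays inf
print(graph_activation(6, chain_edges, coords, [(0, 0.0), (3, 0.5)]))  # RV paced 0.5 s later
```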
[ ]:
# Build PurkinjeTree objects from these ground-truth trees
P_lv = PurkinjeTree(nodes_lv, edges_lv, pmj_lv)
P_rv = PurkinjeTree(nodes_rv, edges_rv, pmj_rv)

# quick baseline sanity
AT_lv_baseline = solve_activation(P_lv, sources=[(0, 0.0)])
AT_rv_baseline = solve_activation(P_rv, sources=[(0, 0.0)])
print("Baseline sanity — LV AT(min/max):", float(AT_lv_baseline.min()), float(AT_lv_baseline.max()))
print("Baseline sanity — RV AT(min/max):", float(AT_rv_baseline.min()), float(AT_rv_baseline.max()))

# Output paths
lv_tree_path = OUT_DIR / "lv_tree_AT.vtu"
rv_tree_path = OUT_DIR / "rv_tree_AT.vtu"
lv_pmj_path  = OUT_DIR / "lv_pmj.vtu"
rv_pmj_path  = OUT_DIR / "rv_pmj.vtu"

P_lv.save(str(lv_tree_path)); P_lv.save_pmjs(str(lv_pmj_path))
P_rv.save(str(rv_tree_path)); P_rv.save_pmjs(str(rv_pmj_path))

print("Wrote:")
print(" -", lv_tree_path)
print(" -", rv_tree_path)
print(" -", lv_pmj_path)
print(" -", rv_pmj_path)

Merge & activation scenarios

We merge LV and RV into a single disconnected tree (LV nodes first, then RV) and define:

  • baseline: [(LV, 0.0), (RV, 0.0)]

  • rv_only: [(RV, 0.0)] → LV becomes unreached (we mask huge times)

  • biv: [(RV, 0.0), (LV, VV_DELAY_MS/1000)]

  • biv_paper: [(LV, 0.0), (RV, 0.075)], i.e., the LV is paced 75 ms before the RV (ROOT_TIME_MS = -75)

Units: the solver expects seconds; we convert milliseconds as needed.
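Merging two trees into one disconnected graph only requires offsetting the second tree's indices by the first tree's node count. That is why the RV root ends up at index n_lv. The pure-Python sketch below does the same bookkeeping as merge_trees, plus the inverse split used when results are mapped back per chamber. Both helper names are hypothetical.

```python
def merge_index_lists(n_a, edges_a, pmj_a, edges_b, pmj_b):
    """Concatenate two trees; tree B's node indices are shifted by n_a."""
    edges = list(edges_a) + [(u + n_a, v + n_a) for u, v in edges_b]
    pmj = list(pmj_a) + [p + n_a for p in pmj_b]
    return edges, pmj


def split_values(values, n_a):
    """Split a per-node array of the merged tree back into (A, B) parts."""
    return values[:n_a], values[n_a:]


# Tree A has 3 nodes, tree B has 2; B's root (local index 0) becomes global index 3
merged_edges, merged_pmj = merge_index_lists(3, [(0, 1), (1, 2)], [2], [(0, 1)], [1])
print(merged_edges, merged_pmj)
at_a, at_b = split_values([0.0, 0.1, 0.2, 0.05, 0.15], 3)
print(at_a, at_b)
```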

[ ]:
# --- Helpers + Merge LV/RV into a single (disconnected) tree ---

INF_CUTOFF = 1e6  # treat >= this as "unreached" for display

def merge_trees(nodes_a, edges_a, pmj_a, nodes_b, edges_b, pmj_b):
    nA = nodes_a.shape[0]
    nodes = np.vstack([nodes_a, nodes_b])
    edges = np.vstack([edges_a, edges_b + nA])
    pmj   = np.concatenate([pmj_a, pmj_b + nA])
    return nodes, edges, pmj

def build_P(nodes, edges, pmj) -> PurkinjeTree:
    return PurkinjeTree(
        nodes=np.asarray(nodes, dtype=float),
        connectivity=np.asarray(edges, dtype=int),
        end_nodes=np.asarray(pmj, dtype=int)
    )

def nearest_map_activation(surface: pv.PolyData, nodes_xyz: np.ndarray, AT: np.ndarray, chunk=2000):
    """Nearest-neighbor map of node AT to each surface vertex."""
    pts = surface.points
    out = np.empty(pts.shape[0], dtype=float)
    for i in range(0, pts.shape[0], chunk):
        j = min(i + chunk, pts.shape[0])
        d2 = np.sum((pts[i:j, None, :] - nodes_xyz[None, :, :])**2, axis=2)
        out[i:j] = AT[np.argmin(d2, axis=1)]
    s = surface.copy(deep=True)
    s.point_data.clear()
    s.point_data["AT"] = out
    return s
[ ]:
# Merge
nodes_biv, edges_biv, pmj_biv = merge_trees(nodes_lv, edges_lv, pmj_lv, nodes_rv, edges_rv, pmj_rv)
P_biv = build_P(nodes_biv, edges_biv, pmj_biv)

# Root indices in merged space (LV block first, then RV)
n_lv = nodes_lv.shape[0]
idx_lv_root = 0
idx_rv_root = n_lv

print("Merged tree:", nodes_biv.shape, edges_biv.shape, pmj_biv.shape)
print("Root indices → LV:", idx_lv_root, "RV:", idx_rv_root)
[ ]:
# --- Scenarios: baseline, rv_only, biv (uses VV_DELAY_MS from above) ---

def finite_minmax(a):
    f = np.isfinite(a) & (a < INF_CUTOFF)
    return (float(np.nanmin(a[f])) if np.any(f) else np.nan), (float(np.nanmax(a[f])) if np.any(f) else np.nan)

scenarios = {}

# both roots @ 0
scenarios["baseline"] = solve_activation(P_biv, [(idx_lv_root, 0.0), (idx_rv_root, 0.0)])

# RV only (LV should be "unreached" → very large)
scenarios["rv_only"]  = solve_activation(P_biv, [(idx_rv_root, 0.0)])

# BiV: RV @ 0, LV @ delay (ms → s)
scenarios["biv"]      = solve_activation(P_biv, [(idx_rv_root, 0.0), (idx_lv_root, VV_DELAY_MS/1000.0)])


# BiV with paper-style VV offset (LV early by 75 ms if negative)
ROOT_TIME_MS = -75.0  # LV earlier by 75 ms (paper demo)
src = ([(idx_lv_root, 0.0), (idx_rv_root, abs(ROOT_TIME_MS)/1000.0)]
       if ROOT_TIME_MS < 0
       else [(idx_rv_root, 0.0), (idx_lv_root, abs(ROOT_TIME_MS)/1000.0)])
scenarios["biv_paper"] = solve_activation(P_biv, src)

for name, AT in scenarios.items():
    lo, hi = finite_minmax(AT)
    print(f"{name}: AT finite min/max = {lo:.4f} / {hi:.4f}")

Map activation to surfaces & save VTP

Activation at nodes is transferred to the LV/RV surfaces via nearest neighbor and written as:

  • lv_surface_AT_<scenario>.vtp

  • rv_surface_AT_<scenario>.vtp

Mapped scalar name: AT. Unreached nodes (extremely large activation times) are replaced with NaN to keep colorbars useful.
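The transfer is a plain nearest-neighbour lookup followed by masking. Below is a minimal pure-Python sketch of the same idea. The `map_nearest` helper is hypothetical; like nearest_map_activation it uses brute force, and it applies the same 1e6 "unreached" threshold as INF_CUTOFF.

```python
import math

INF_CUTOFF = 1e6  # same "unreached" threshold as the notebook


def map_nearest(surface_pts, node_pts, node_at, cutoff=INF_CUTOFF):
    """Give each surface point the AT of its nearest tree node; mask unreached as NaN."""
    out = []
    for p in surface_pts:
        # Brute-force nearest node by squared Euclidean distance
        j = min(range(len(node_pts)),
                key=lambda k: sum((a - b) ** 2 for a, b in zip(p, node_pts[k])))
        at = node_at[j]
        out.append(at if math.isfinite(at) and at < cutoff else math.nan)
    return out


tree_pts = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
tree_at = [0.02, math.inf]  # second node unreached
surface_pts = [(1.0, 0.0, 0.0), (9.0, 1.0, 0.0)]
print(map_nearest(surface_pts, tree_pts, tree_at))
```

For large meshes, a KD-tree (for example scipy.spatial.cKDTree with its query method) is the usual faster alternative to brute force.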

[ ]:
# --- Map node AT → LV/RV surfaces and export VTP per scenario ---

def export_surface_AT(name: str, AT_biv: np.ndarray):
    # split per chamber
    AT_lv = AT_biv[:n_lv]
    AT_rv = AT_biv[n_lv:]
    # map to surfaces
    lv_surf = nearest_map_activation(surf_lv, nodes_lv, AT_lv)
    rv_surf = nearest_map_activation(surf_rv, nodes_rv, AT_rv)
    # mask unreached (huge) times as NaN for nicer colorbars in viewers
    for s in (lv_surf, rv_surf):
        if "AT" in s.point_data:
            at = s.point_data["AT"]
            at = np.where(np.isfinite(at) & (at < INF_CUTOFF), at, np.nan)
            s.point_data["AT"] = at
    # save
    lv_path = OUT_DIR / f"lv_surface_AT_{name}.vtp"
    rv_path = OUT_DIR / f"rv_surface_AT_{name}.vtp"
    lv_surf.save(str(lv_path))
    rv_surf.save(str(rv_path))
    print("Saved:", lv_path.name, "|", rv_path.name)

for scen, AT in scenarios.items():
    export_surface_AT(scen, AT)
[ ]:
# --- Save the merged tree geometry per scenario (same geometry each time) ---

def save_tree_geom(name: str):
    Ptmp = build_P(nodes_biv, edges_biv, pmj_biv)
    path = OUT_DIR / f"biv_tree_{name}.vtu"
    Ptmp.save(str(path))
    print("Saved:", path.name)

for scen in scenarios.keys():
    save_tree_geom(scen)

Interactive 3D viewer (Plotly)

Use the dropdown to switch scenarios. Layers include:

  • LV/RV surfaces colored by AT (semi-transparent)

  • LV/RV tree polylines (bold, high contrast)

  • Node points colored by AT

  • PMJs (squares)

Tips:

  • Click legend entries to toggle layers.

  • Adjust styling at the top of the cell:

    • SURF_OPACITY (e.g., 0.08–0.15)

    • TREE_WIDTH (e.g., 3–9)

    • NODE_SIZE, PMJ_SIZE
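The dropdown works by toggling trace visibility. Each scenario contributes a fixed block of traces (8 here), and each button sets visible=True only for its own block. A standalone sketch of that mask construction, using a hypothetical `visibility_masks` helper:

```python
def visibility_masks(n_scenarios: int, traces_per_scenario: int) -> list[list[bool]]:
    """One visibility list per scenario; True only for that scenario's trace block."""
    total = n_scenarios * traces_per_scenario
    masks = []
    for i in range(n_scenarios):
        start = i * traces_per_scenario
        masks.append([start <= k < start + traces_per_scenario for k in range(total)])
    return masks


demo_masks = visibility_masks(n_scenarios=4, traces_per_scenario=8)
print(len(demo_masks), len(demo_masks[0]), sum(demo_masks[2]))  # 4 masks, 32 traces, 8 visible
```

Each list is what a Plotly button with method="update" receives as args=[{"visible": mask}], as in the viewer cell.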

[ ]:
# --- Plotly viewer with scenario dropdown (Colab-friendly, high-contrast) ---

SURF_OPACITY = 0.12
TREE_WIDTH   = 3
NODE_SIZE    = 5
PMJ_SIZE     = 4
LV_TREE_COL  = "#111111"
RV_TREE_COL  = "#2b6cb0"


def mesh3d_with_AT(surface: pv.PolyData, name):
    pts = surface.points
    faces = surface.faces.reshape(-1, 4)[:, 1:]
    at = surface.point_data.get("AT", None)
    return go.Mesh3d(
        x=pts[:,0], y=pts[:,1], z=pts[:,2],
        i=faces[:,0], j=faces[:,1], k=faces[:,2],
        name=name, opacity=SURF_OPACITY,
        intensity=at if at is not None else None,
        colorscale="Viridis",
        showscale=True if at is not None else False,
        colorbar=dict(title="AT") if at is not None else None,
        lighting=dict(ambient=0.6, diffuse=0.6)
    )

def line_segments(points: np.ndarray, edges: np.ndarray, name, color, width=TREE_WIDTH):
    xs, ys, zs = [], [], []
    for u, v in edges:
        xs += [points[u,0], points[v,0], None]
        ys += [points[u,1], points[v,1], None]
        zs += [points[u,2], points[v,2], None]
    return go.Scatter3d(x=xs, y=ys, z=zs, mode="lines",
                        line=dict(width=width, color=color),
                        name=name, showlegend=True)

def node_points(points: np.ndarray, at: np.ndarray, name):
    c = np.where(np.isfinite(at) & (at < INF_CUTOFF), at, np.nan)
    return go.Scatter3d(
        x=points[:,0], y=points[:,1], z=points[:,2],
        mode="markers", name=name, showlegend=True,
        marker=dict(size=NODE_SIZE, color=c, colorscale="Viridis",
                    line=dict(width=0.8, color="white"),
                    opacity=0.98, showscale=False)
    )

def pmj_points(points: np.ndarray, name, color):
    if points.size == 0:
        return go.Scatter3d(x=[], y=[], z=[], mode="markers", name=name)
    return go.Scatter3d(
        x=points[:,0], y=points[:,1], z=points[:,2],
        mode="markers", name=name, showlegend=True,
        marker=dict(size=PMJ_SIZE, color=color, symbol="square", opacity=0.95)
    )

# Prepare per-scenario surfaces (already saved above, but we’ll compute in-memory again)
surfaces_by_scen = {}
ats_nodes_by_scen = {}
for scen, AT in scenarios.items():
    AT_lv = AT[:n_lv]
    AT_rv = AT[n_lv:]
    lv_surf = nearest_map_activation(surf_lv, nodes_lv, AT_lv)
    rv_surf = nearest_map_activation(surf_rv, nodes_rv, AT_rv)
    # mask huge for display
    for s in (lv_surf, rv_surf):
        at = s.point_data["AT"]
        s.point_data["AT"] = np.where(np.isfinite(at) & (at < INF_CUTOFF), at, np.nan)
    surfaces_by_scen[scen] = (lv_surf, rv_surf)
    ats_nodes_by_scen[scen] = (AT_lv, AT_rv)

pmj_coords_lv = nodes_lv[pmj_lv] if pmj_lv.size else np.empty((0,3))
pmj_coords_rv = nodes_rv[pmj_rv] if pmj_rv.size else np.empty((0,3))

# scenario order
ordered   = ["baseline", "rv_only", "biv", "biv_paper"]
scen_list = [s for s in ordered if s in scenarios] or list(scenarios.keys())


# build traces
traces, masks = [], []
for scen in scen_list:
    lv_surf, rv_surf = surfaces_by_scen[scen]
    AT_lv, AT_rv = ats_nodes_by_scen[scen]

    traces += [
        mesh3d_with_AT(lv_surf, "LV surface"),
        mesh3d_with_AT(rv_surf, "RV surface"),
        line_segments(nodes_lv, edges_lv, "LV tree", LV_TREE_COL),
        line_segments(nodes_rv, edges_rv, "RV tree", RV_TREE_COL),
        node_points(nodes_lv, AT_lv, "LV nodes (AT)"),
        node_points(nodes_rv, AT_rv, "RV nodes (AT)"),
        pmj_points(pmj_coords_lv, "LV PMJs", "crimson"),
        pmj_points(pmj_coords_rv, "RV PMJs", "orange"),
    ]
    mask = [False] * (8 * len(scen_list))
    base = scen_list.index(scen) * 8
    for k in range(base, base+8): mask[k] = True
    masks.append(mask)

fig = go.Figure(data=traces)
# show first scenario
if masks:
    for tr, vis in zip(fig.data, masks[0]):
        tr.visible = vis

# dropdown
fig.update_layout(
    updatemenus=[dict(type="dropdown",
                      x=0.01, y=0.99, showactive=True,
                      buttons=[dict(method="update", label=sc,
                                    args=[{"visible": masks[i]}])
                               for i, sc in enumerate(scen_list)])],
    scene=dict(aspectmode="data"),
    margin=dict(l=0, r=0, t=30, b=0),
    title=f"CRT Demo — scenarios: {', '.join(scen_list)}"
)
fig.show()

Tuning & variants

  • Want denser coverage? Increase N_it (more branch generations) or adjust length (mean branch length) slightly.

  • Shift VV timing: set VV_DELAY_MS at the top (e.g., -40, 0, +40).

  • Compare to paper timing: use biv_paper (LV early by 75 ms).

  • Different seeds: replace LV_SEEDS / RV_SEEDS or switch to auto-seeding.

Troubleshooting

  • HTTP 404 on OBJ: check that the repository URLs are reachable; the notebook tries the main branch first, then the examples branch.

  • “Initial branch goes out of the domain”: the notebook will retry with a shorter init_length. If it persists, pick a different seed pair.

  • Flat colorbar in rv_only: expected, because the unreached LV is masked to NaN.

  • Slow mapping: reduce the surface vertex count, or raise the chunk argument of nearest_map_activation (larger chunks mean fewer batches but more memory per batch).
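Memory use in nearest_map_activation is dominated by the temporary difference array of shape (chunk, n_nodes, 3) in float64. Squaring creates a second array of the same size, so the peak is roughly double. A back-of-the-envelope sketch for choosing chunk follows; both helpers are hypothetical and the node count is illustrative.

```python
def diff_array_bytes(chunk: int, n_nodes: int, itemsize: int = 8) -> int:
    """Bytes of the (chunk, n_nodes, 3) float64 difference array built per batch."""
    return chunk * n_nodes * 3 * itemsize


def max_chunk(n_nodes: int, budget_bytes: int, itemsize: int = 8) -> int:
    """Largest chunk whose difference array fits within budget_bytes (at least 1)."""
    return max(1, budget_bytes // (n_nodes * 3 * itemsize))


# Example: a hypothetical 5,000-node merged tree with the default chunk=2000
print(diff_array_bytes(2000, 5000) / 1e6, "MB")
print(max_chunk(5000, budget_bytes=100_000_000))  # chunk keeping one array under ~100 MB
```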