# Source code for pasted._generator

"""
pasted._generator
=================
High-level API:

- :class:`Structure`          — dataclass holding one generated structure.
- :class:`StructureGenerator` — stateful generator (class API).
- :func:`generate`            — convenience functional wrapper.
"""

from __future__ import annotations

import random
import sys
import warnings
from collections import Counter
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path

from ._atoms import (
    _Z_TO_SYM,
    ATOMIC_NUMBERS,
    _cov_radius_ang,
    default_element_pool,
    parse_element_spec,
    parse_filter,
    validate_charge_mult,
)
from ._io import _fmt, format_xyz, parse_xyz
from ._metrics import compute_all_metrics, passes_filters
from ._placement import (
    Vec3,
    add_hydrogen,
    place_chain,
    place_gas,
    place_maxent,
    place_shell,
    relax_positions,
)

# ---------------------------------------------------------------------------
# Structure dataclass
# ---------------------------------------------------------------------------


@dataclass
class Structure:
    """One generated atomic structure together with its disorder metrics.

    Attributes
    ----------
    atoms:
        Element symbols, one entry per atom.
    positions:
        Cartesian ``(x, y, z)`` coordinates in Å, aligned with ``atoms``.
    charge:
        Total system charge.
    mult:
        Spin multiplicity 2S+1.
    metrics:
        Computed disorder metrics (see :data:`pasted._atoms.ALL_METRICS`).
    mode:
        Placement mode used (``"gas"``, ``"chain"``, or ``"shell"``).
    sample_index:
        1-based index within the batch of structures that passed filters.
    center_sym:
        Element symbol of the shell center atom (shell mode only).
    seed:
        Random seed used for generation (``None`` if unseeded).
    """

    atoms: list[str]
    positions: list[Vec3]
    charge: int
    mult: int
    metrics: dict[str, float]
    mode: str
    sample_index: int = 0
    center_sym: str | None = None
    seed: int | None = None

    # ------------------------------------------------------------------ #
    # XYZ output
    # ------------------------------------------------------------------ #

    def to_xyz(self, prefix: str = "") -> str:
        """Serialise to extended XYZ format.

        Parameters
        ----------
        prefix:
            Custom prefix for the comment line. When omitted, the standard
            ``"sample=N mode=M …"`` header is built automatically.

        Returns
        -------
        Multi-line string (no trailing newline).
        """
        if not prefix:
            # Auto-generated header: sample index, mode, then optional
            # shell-center and seed annotations.
            prefix = f"sample={self.sample_index} mode={self.mode}"
            if self.mode == "shell" and self.center_sym:
                prefix += f" center={self.center_sym}(Z={ATOMIC_NUMBERS[self.center_sym]})"
            if self.seed is not None:
                prefix += f" seed={self.seed}"
        return format_xyz(
            self.atoms, self.positions, self.charge, self.mult, self.metrics, prefix
        )

    def write_xyz(self, path: str | Path, *, append: bool = True) -> None:
        """Write this structure to an XYZ file.

        Parameters
        ----------
        path:
            Output file path.
        append:
            If ``True`` (default) the file is opened in append mode so that
            multiple structures can be written in sequence. Use
            ``append=False`` to overwrite.
        """
        with Path(path).open("a" if append else "w") as fh:
            fh.write(self.to_xyz() + "\n")

    # ------------------------------------------------------------------ #
    # Dunder helpers
    # ------------------------------------------------------------------ #

    def __len__(self) -> int:
        """Number of atoms in the structure."""
        return len(self.atoms)

    # ------------------------------------------------------------------ #
    # XYZ import
    # ------------------------------------------------------------------ #

    @classmethod
    def from_xyz(
        cls,
        source: str | Path,
        *,
        frame: int = 0,
        recompute_metrics: bool = True,
        cutoff: float | None = None,
        n_bins: int = 20,
        w_atom: float = 0.5,
        w_spatial: float = 0.5,
        cov_scale: float = 1.0,
    ) -> Structure:
        """Load a :class:`Structure` from an XYZ file or raw XYZ string.

        Handles both plain XYZ and PASTED extended XYZ (``charge=``,
        ``mult=``, and metric tokens on the comment line). With
        *recompute_metrics* (default) all disorder metrics are recomputed
        from the loaded geometry so the result is directly usable as
        optimizer input or for filtering.

        Parameters
        ----------
        source:
            Path to an XYZ file **or** a raw XYZ string.
        frame:
            Zero-based frame index when *source* holds several concatenated
            structures (default: first frame).
        recompute_metrics:
            Recompute all disorder metrics after loading. ``False`` keeps
            whatever metric values were embedded in the extended XYZ comment
            (or an empty dict for plain XYZ).
        cutoff:
            Distance cutoff (Å) for metric computation. Auto-computed from
            the loaded composition when ``None``.
        n_bins:
            Histogram bins for ``H_spatial`` / ``RDF_dev`` (default: 20).
        w_atom:
            Weight of ``H_atom`` in ``H_total`` (default: 0.5).
        w_spatial:
            Weight of ``H_spatial`` in ``H_total`` (default: 0.5).
        cov_scale:
            Minimum distance scale factor used for metrics (default: 1.0).

        Returns
        -------
        Structure

        Raises
        ------
        ValueError
            When the file / string cannot be parsed, or *frame* is out of
            range.
        """
        # A string containing a newline can only be raw XYZ text; anything
        # else (a Path object, or a single-line string) is tried as a path
        # first and falls back to being treated as XYZ text.
        candidate: Path | None = None
        if not (isinstance(source, str) and "\n" in str(source)):
            candidate = Path(source)
        if candidate is not None and candidate.exists():
            text = candidate.read_text()
        else:
            text = str(source)

        frames = parse_xyz(text)
        if not frames:
            raise ValueError("No frames found in XYZ source.")
        if frame < 0 or frame >= len(frames):
            raise ValueError(
                f"frame={frame} out of range; source contains {len(frames)} frame(s)."
            )
        atoms, positions, charge, mult, embedded_metrics = frames[frame]

        if not recompute_metrics:
            metrics = embedded_metrics
        else:
            if cutoff is None:
                # Auto cutoff: cov_scale × 1.5 × median pairwise sum of
                # covalent radii over the loaded composition (self-pairs
                # included, matching StructureGenerator._resolve_cutoff).
                radii = [_cov_radius_ang(a) for a in atoms]
                pair_sums = sorted(
                    r_a + r_b for i, r_a in enumerate(radii) for r_b in radii[i:]
                )
                cutoff = cov_scale * 1.5 * pair_sums[len(pair_sums) // 2]
            metrics = compute_all_metrics(
                atoms, positions, n_bins, w_atom, w_spatial, cutoff, cov_scale
            )

        return cls(
            atoms=list(atoms),
            positions=list(positions),
            charge=charge,
            mult=mult,
            metrics=metrics,
            mode="loaded_xyz",
        )

    def __repr__(self) -> str:
        tally = Counter(self.atoms)
        comp = "".join(
            sym if count == 1 else f"{sym}{count}"
            for sym, count in sorted(tally.items())
        )
        h_total = self.metrics.get("H_total", float("nan"))
        return f"Structure(n={len(self)}, comp={comp!r}, mode={self.mode!r}, H_total={h_total:.3f})"
# --------------------------------------------------------------------------- # GenerationResult # ---------------------------------------------------------------------------
@dataclass
class GenerationResult:
    """Return value of :func:`generate` and :meth:`StructureGenerator.generate`.

    Behaves like a ``list[Structure]`` in all normal usage (indexing,
    iteration, ``len``, boolean test, ``for s in result``) while also
    carrying metadata about how many attempts were made and why samples
    were rejected. This metadata is especially useful in automated
    pipelines (ASE, high-throughput workflows), where a silent empty list
    would be indistinguishable from a successful run with no results.

    Attributes
    ----------
    structures:
        Structures that passed all filters.
    n_attempted:
        Total placement attempts made.
    n_passed:
        Number of structures that passed all filters (equals
        ``len(structures)`` unless the caller mutates the list).
    n_rejected_parity:
        Attempts rejected by the charge/multiplicity parity check.
    n_rejected_filter:
        Attempts rejected by user-supplied metric filters.
    n_success_target:
        The ``n_success`` value in effect during generation (``None`` when
        not set).

    Examples
    --------
    Drop-in replacement for ``list[Structure]``::

        result = generate(n_atoms=10, charge=0, mult=1, mode="gas",
                          region="sphere:8", elements="6,7,8",
                          n_samples=20, seed=0)
        for s in result:          # iterates like a list
            print(s.to_xyz())
        print(len(result))        # number that passed

    Inspect rejection metadata::

        if result.n_rejected_parity > 0:
            print(f"{result.n_rejected_parity} samples failed parity check")
        print(result.summary())

    Notes
    -----
    ``GenerationResult`` is a :func:`~dataclasses.dataclass`; downstream
    code should treat it as immutable. The ``structures`` field is a plain
    ``list`` and may be sorted or sliced freely.
    """

    structures: list[Structure] = field(default_factory=list)
    n_attempted: int = 0
    n_passed: int = 0
    n_rejected_parity: int = 0
    n_rejected_filter: int = 0
    n_success_target: int | None = None

    # ------------------------------------------------------------------ #
    # list-compatible interface
    # ------------------------------------------------------------------ #

    def __len__(self) -> int:
        return len(self.structures)

    def __iter__(self) -> Iterator[Structure]:
        return iter(self.structures)

    def __getitem__(self, index: int | slice) -> Structure | list[Structure]:
        # list.__getitem__ already handles both int and slice indices, so
        # the previous isinstance(slice) branch was redundant dead code.
        return self.structures[index]

    def __bool__(self) -> bool:
        return bool(self.structures)

    def __add__(self, other: GenerationResult) -> GenerationResult:
        """Merge two :class:`GenerationResult` objects into one.

        Combines structures and accumulates all counters so that batch
        workflows can collect results across multiple calls and treat them
        as a single result::

            r1 = generate(..., n_samples=20, seed=0)
            r2 = generate(..., n_samples=20, seed=1)
            combined = r1 + r2
            print(len(combined))       # up to 40
            print(combined.summary())

        Parameters
        ----------
        other:
            Another :class:`GenerationResult` to merge into this one.

        Returns
        -------
        GenerationResult
            New result containing all structures from both operands.
            ``n_success_target`` is taken from *self* when set, otherwise
            from *other*.
        """
        if not isinstance(other, GenerationResult):
            return NotImplemented
        return GenerationResult(
            structures=self.structures + other.structures,
            n_attempted=self.n_attempted + other.n_attempted,
            n_passed=self.n_passed + other.n_passed,
            n_rejected_parity=self.n_rejected_parity + other.n_rejected_parity,
            n_rejected_filter=self.n_rejected_filter + other.n_rejected_filter,
            n_success_target=self.n_success_target
            if self.n_success_target is not None
            else other.n_success_target,
        )

    def __repr__(self) -> str:
        return (
            f"GenerationResult("
            f"passed={self.n_passed}, "
            f"attempted={self.n_attempted}, "
            f"rejected_parity={self.n_rejected_parity}, "
            f"rejected_filter={self.n_rejected_filter})"
        )

    # ------------------------------------------------------------------ #
    # Metadata helpers
    # ------------------------------------------------------------------ #

    def summary(self) -> str:
        """Return a human-readable one-line summary of the generation run.

        Returns
        -------
        str
            E.g. ``"passed=5 attempted=20 rejected_parity=2 rejected_filter=13"``.
        """
        parts = [
            f"passed={self.n_passed}",
            f"attempted={self.n_attempted}",
            f"rejected_parity={self.n_rejected_parity}",
            f"rejected_filter={self.n_rejected_filter}",
        ]
        if self.n_success_target is not None:
            parts.append(f"n_success_target={self.n_success_target}")
        return " ".join(parts)
# --------------------------------------------------------------------------- # StructureGenerator # ---------------------------------------------------------------------------
[docs] class StructureGenerator: """Generate random atomic structures with disorder metrics. All parameters use Python snake_case names that correspond 1-to-1 with their CLI ``--flag`` counterparts. Parameters ---------- n_atoms: Number of atoms per structure (before optional H augmentation). charge: Total system charge (applied to every structure). mult: Spin multiplicity 2S+1. mode: Placement mode: ``"gas"`` (default), ``"chain"``, or ``"shell"``. region: [gas] Region spec: ``"sphere:R"`` | ``"box:L"`` | ``"box:LX,LY,LZ"``. Required when *mode="gas"*. branch_prob: [chain] Branching probability (default: 0.3). chain_persist: [chain] Directional persistence ∈ [0, 1] (default: 0.5). chain_bias: [chain] Global-axis drift strength ∈ [0, 1] (default: 0.0). The direction of the first bond becomes the bias axis; each subsequent step is blended toward that axis before normalisation. 0.0 → no bias (backwards-compatible); higher values produce more elongated structures with larger ``shape_aniso``. bond_range: [chain / shell tails] Bond-length range in Å (default: ``(1.2, 1.6)``). center_z: [shell] Atomic number of center atom. ``None`` → random per sample. coord_range: [shell] Coordination-number range (default: ``(4, 8)``). shell_radius: [shell] Shell-radius range in Å (default: ``(1.8, 2.5)``). elements: Element pool. A spec string such as ``"1-30"`` or ``"6,7,8"``, an explicit list of element symbols, or ``None`` for all Z = 1–106. element_fractions: Relative sampling weights for elements in the pool, as a ``{symbol: weight}`` dict (e.g. ``{"C": 0.5, "N": 0.3, "O": 0.2}``). Weights are *relative* — they are normalised internally and need not sum to 1. Elements absent from the dict receive a weight of 1.0. When ``None`` (default), every element in the pool is sampled with equal probability. element_min_counts: Minimum number of atoms per element guaranteed in every generated structure (e.g. ``{"C": 2, "N": 1}``). 
The required atoms are placed first; remaining slots are filled by weighted random sampling. ``None`` (default) → no lower bounds. The sum of all minimum counts must not exceed ``n_atoms``. element_max_counts: Maximum number of atoms allowed per element (e.g. ``{"N": 5, "O": 3}``). Elements that have reached their cap are excluded from sampling for the remaining slots. ``None`` (default) → no upper bounds. .. note:: When both *element_min_counts* and *element_max_counts* are given, each element's min must be ≤ its max. .. note:: The automatic hydrogen augmentation step (``add_hydrogen=True``) runs *after* the constrained sampling and may temporarily exceed *element_max_counts* for H. Set ``add_hydrogen=False`` if H count limits are critical. cov_scale: Minimum-distance scale factor: ``d_min(i,j) = cov_scale × (r_i + r_j)`` using Pyykkö (2009) single-bond covalent radii. Default: ``1.0``. relax_cycles: Maximum repulsion-relaxation iterations (default: 1500). add_hydrogen: Automatically append H atoms when H is in the pool but the sampled composition contains none (default: ``True``). n_samples: Maximum number of placement attempts (default: 1). Use ``0`` to allow unlimited attempts (only valid when *n_success* is also set, otherwise a :exc:`ValueError` is raised). n_success: Target number of structures that must pass all filters before generation stops (default: ``None``). - ``None`` → generate exactly *n_samples* attempts and return all that passed (original behaviour). - ``N > 0`` with ``n_samples > 0`` → stop as soon as *N* structures pass **or** *n_samples* attempts are exhausted, whichever comes first. Returns the structures collected so far with a warning if fewer than *N* were found. - ``N > 0`` with ``n_samples = 0`` → unlimited attempts; stop only when *N* structures have passed. seed: Random seed for reproducibility (``None`` → non-deterministic). n_bins: Histogram bins for ``H_spatial`` and ``RDF_dev`` (default: 20). 
w_atom: Weight of ``H_atom`` in ``H_total`` (default: 0.5). w_spatial: Weight of ``H_spatial`` in ``H_total`` (default: 0.5). cutoff: Distance cutoff in Å for Steinhardt and graph metrics. ``None`` → auto-computed as ``cov_scale × 1.5 × median(r_i + r_j)`` over the element pool. filters: Filter strings of the form ``"METRIC:MIN:MAX"`` (use ``"-"`` for an open bound). Only structures satisfying *all* filters are returned. verbose: Print progress and statistics to *stderr* (default: ``False``). The CLI always passes ``True``; library callers usually leave it off. Examples -------- Class API:: from pasted import StructureGenerator gen = StructureGenerator( n_atoms=12, charge=0, mult=1, mode="gas", region="sphere:9", elements="1-30", n_samples=50, seed=42, filters=["H_total:2.0:-"], ) structures = gen.generate() for s in structures: print(s) Functional API:: from pasted import generate structures = generate( n_atoms=12, charge=0, mult=1, mode="chain", elements="6,7,8", n_samples=20, seed=0, ) """ def __init__( self, *, n_atoms: int, charge: int, mult: int, mode: str = "gas", region: str | None = None, branch_prob: float = 0.3, chain_persist: float = 0.5, chain_bias: float = 0.0, bond_range: tuple[float, float] = (1.2, 1.6), center_z: int | None = None, coord_range: tuple[int, int] = (4, 8), shell_radius: tuple[float, float] = (1.8, 2.5), elements: str | list[str] | None = None, element_fractions: dict[str, float] | None = None, element_min_counts: dict[str, int] | None = None, element_max_counts: dict[str, int] | None = None, cov_scale: float = 1.0, relax_cycles: int = 1500, maxent_steps: int = 300, maxent_lr: float = 0.05, maxent_cutoff_scale: float = 2.5, trust_radius: float = 0.5, convergence_tol: float = 1e-3, add_hydrogen: bool = True, n_samples: int = 1, n_success: int | None = None, seed: int | None = None, n_bins: int = 20, w_atom: float = 0.5, w_spatial: float = 0.5, cutoff: float | None = None, filters: list[str] | None = None, verbose: bool = False, ) -> 
None: if mode not in ("gas", "chain", "shell", "maxent"): raise ValueError( f"mode must be 'gas', 'chain', 'shell', or 'maxent'; got {mode!r}" ) if mode in ("gas", "maxent") and region is None: raise ValueError("region is required when mode='gas' or mode='maxent'") self.n_atoms = n_atoms self.charge = charge self.mult = mult self.mode = mode self.region = region self.branch_prob = branch_prob self.chain_persist = chain_persist self.chain_bias = chain_bias self.bond_range = bond_range self.center_z = center_z self.coord_range = coord_range self.shell_radius = shell_radius self.cov_scale = cov_scale self.relax_cycles = relax_cycles self.maxent_steps = maxent_steps self.maxent_lr = maxent_lr self.maxent_cutoff_scale = maxent_cutoff_scale self.trust_radius = trust_radius self.convergence_tol = convergence_tol self._add_hydrogen = add_hydrogen self.n_samples = n_samples self.n_success = n_success self.seed = seed self.n_bins = n_bins self.w_atom = w_atom self.w_spatial = w_spatial self.verbose = verbose # ── n_samples / n_success validation ──────────────────────────── if n_samples == 0 and n_success is None: raise ValueError( "n_samples=0 (unlimited) requires n_success to be set; " "otherwise generation would run forever." ) if n_success is not None and n_success < 1: raise ValueError(f"n_success must be >= 1; got {n_success}.") # ── Element pool ──────────────────────────────────────────────── if elements is None: self._element_pool: list[str] = default_element_pool() elif isinstance(elements, str): self._element_pool = parse_element_spec(elements) else: self._element_pool = list(elements) # ── Element fractions ──────────────────────────────────────────── # Build a normalised weight list aligned with self._element_pool. # Unknown keys in the dict raise ValueError immediately. 
if element_fractions is not None: unknown = set(element_fractions) - set(self._element_pool) if unknown: raise ValueError( f"element_fractions contains symbols not in the element pool: " f"{sorted(unknown)}" ) weights = [float(element_fractions.get(sym, 1.0)) for sym in self._element_pool] if any(w < 0 for w in weights): raise ValueError("element_fractions weights must be non-negative.") total = sum(weights) if total == 0: raise ValueError("element_fractions weights must not all be zero.") self._element_weights: list[float] = [w / total for w in weights] else: n = len(self._element_pool) self._element_weights = [1.0 / n] * n # ── Element min/max counts ─────────────────────────────────────── # Validate and store; actual enforcement happens in stream(). if element_min_counts is not None: unknown_min = set(element_min_counts) - set(self._element_pool) if unknown_min: raise ValueError( f"element_min_counts contains symbols not in the element pool: " f"{sorted(unknown_min)}" ) if any(v < 0 for v in element_min_counts.values()): raise ValueError("element_min_counts values must be non-negative.") total_min = sum(element_min_counts.values()) if total_min > n_atoms: raise ValueError( f"Sum of element_min_counts ({total_min}) exceeds n_atoms ({n_atoms})." ) if element_max_counts is not None: unknown_max = set(element_max_counts) - set(self._element_pool) if unknown_max: raise ValueError( f"element_max_counts contains symbols not in the element pool: " f"{sorted(unknown_max)}" ) if any(v < 0 for v in element_max_counts.values()): raise ValueError("element_max_counts values must be non-negative.") if element_min_counts is not None and element_max_counts is not None: for sym in element_min_counts: lo = element_min_counts[sym] hi = element_max_counts.get(sym, lo) if lo > hi: raise ValueError( f"element_min_counts[{sym!r}]={lo} > " f"element_max_counts[{sym!r}]={hi}." 
) self._element_min_counts: dict[str, int] = dict(element_min_counts or {}) self._element_max_counts: dict[str, int] = dict(element_max_counts or {}) # ── Filters ───────────────────────────────────────────────────── self._filters: list[tuple[str, float, float]] = [parse_filter(f) for f in (filters or [])] # ── Cutoff ────────────────────────────────────────────────────── self._cutoff: float = self._resolve_cutoff(cutoff) # ── Shell center ───────────────────────────────────────────────── self._fixed_center_sym: str | None = None if mode == "shell" and center_z is not None: if center_z not in _Z_TO_SYM: raise ValueError(f"center_z={center_z}: unknown atomic number.") sym = _Z_TO_SYM[center_z] if sym not in self._element_pool: raise ValueError(f"center_z={center_z} ({sym}) is not in the element pool.") self._fixed_center_sym = sym if self.verbose: self._log(f"[pool] {len(self._element_pool)} elements in pool") if mode == "shell": if self._fixed_center_sym: self._log( f"[shell] center fixed: {self._fixed_center_sym} " f"(Z={ATOMIC_NUMBERS[self._fixed_center_sym]})" ) else: self._log("[shell] center: random per sample (chaos mode)") # ------------------------------------------------------------------ # # Internal helpers # # ------------------------------------------------------------------ # def _log(self, msg: str) -> None: """Print *msg* to stderr when verbose mode is active.""" print(msg, file=sys.stderr) def _resolve_cutoff(self, override: float | None) -> float: if override is not None: if self.verbose: self._log(f"[cutoff] {override:.3f} Å (user-specified)") return override radii = [_cov_radius_ang(s) for s in self._element_pool] pair_sums = sorted(ra + rb for i, ra in enumerate(radii) for rb in radii[i:]) median_sum = pair_sums[len(pair_sums) // 2] cutoff = self.cov_scale * 1.5 * median_sum if self.verbose: self._log( f"[cutoff] {cutoff:.3f} Å (auto: cov_scale={self.cov_scale} × 1.5 × " f"median(r_i+r_j)={median_sum:.3f} Å)" ) return cutoff def 
_sample_atoms(self, rng: random.Random) -> list[str]: """Sample *n_atoms* element symbols respecting fractions and count bounds. Algorithm --------- 1. If no fractions/min/max are configured, falls back to the original uniform ``rng.choice`` per atom (preserves seed parity). 2. Otherwise: place the guaranteed minimum-count atoms first (``element_min_counts``), fill remaining slots by weighted random sampling (``element_fractions``), excluding elements that have reached their ``element_max_counts`` cap, then shuffle. Raises ------ RuntimeError When the constraints cannot be satisfied (e.g. all remaining elements are capped and there are still slots to fill). """ pool = self._element_pool min_c = self._element_min_counts max_c = self._element_max_counts n = len(pool) uniform = (n > 0 and all(abs(w - 1.0 / n) < 1e-12 for w in self._element_weights)) # Fast path: uniform weights, no bounds → identical to original behaviour if uniform and not min_c and not max_c: return [rng.choice(pool) for _ in range(self.n_atoms)] weights = self._element_weights # ── Step 1: fill guaranteed minimum counts ────────────────────── counts: dict[str, int] = {sym: min_c.get(sym, 0) for sym in pool} atoms: list[str] = [] for sym in pool: atoms.extend([sym] * counts[sym]) remaining = self.n_atoms - len(atoms) # ── Step 2: weighted sampling for remaining slots ──────────────── for _ in range(remaining): # Build eligible pool (not yet capped) eligible: list[str] = [] eligible_w: list[float] = [] for sym, w in zip(pool, weights, strict=True): cap = max_c.get(sym, None) if cap is None or counts.get(sym, 0) < cap: eligible.append(sym) eligible_w.append(w) if not eligible: raise RuntimeError( "element_max_counts constraints cannot be satisfied: " "all elements are capped before n_atoms is reached." 
) # Normalise eligible weights and do a weighted choice total_w = sum(eligible_w) cum: list[float] = [] acc = 0.0 for w in eligible_w: acc += w / total_w cum.append(acc) r = rng.random() chosen = eligible[-1] for sym, c in zip(eligible, cum, strict=True): if r <= c: chosen = sym break counts[chosen] = counts.get(chosen, 0) + 1 atoms.append(chosen) # ── Step 3: shuffle so forced atoms don't cluster at front ─────── rng.shuffle(atoms) return atoms # ------------------------------------------------------------------ # # Public properties # # ------------------------------------------------------------------ # @property def element_pool(self) -> list[str]: """A copy of the resolved element pool (list of symbols).""" return list(self._element_pool) @property def cutoff(self) -> float: """Distance cutoff in Å used for Steinhardt and graph metrics.""" return self._cutoff # ------------------------------------------------------------------ # # Generation # # ------------------------------------------------------------------ # # ------------------------------------------------------------------ # # Internal placement dispatch # # ------------------------------------------------------------------ # def _place_one( self, atoms_list: list[str], rng: random.Random, ) -> tuple[list[str], list[Vec3], str | None]: """Run the mode-specific placement and return (atoms, positions, center_sym). Raises ------ RuntimeError, ValueError Propagated from the underlying placement functions. 
""" bond_lo, bond_hi = self.bond_range shell_lo, shell_hi = self.shell_radius coord_lo, coord_hi = self.coord_range center_sym: str | None = None if self.mode == "gas": assert self.region is not None # guaranteed by __init__ validation atoms_out, positions = place_gas( atoms_list, self.region, rng, ) elif self.mode == "chain": atoms_out, positions = place_chain( atoms_list, bond_lo, bond_hi, self.branch_prob, self.chain_persist, rng, chain_bias=self.chain_bias, ) elif self.mode == "maxent": assert self.region is not None atoms_out, positions = place_maxent( atoms_list, self.region, self.cov_scale, rng, maxent_steps=self.maxent_steps, maxent_lr=self.maxent_lr, maxent_cutoff_scale=self.maxent_cutoff_scale, trust_radius=self.trust_radius, convergence_tol=self.convergence_tol, seed=self.seed, ) else: # shell center_sym = ( self._fixed_center_sym if self._fixed_center_sym is not None else rng.choice(atoms_list) ) atoms_out, positions = place_shell( atoms_list, center_sym, coord_lo, coord_hi, shell_lo, shell_hi, bond_lo, bond_hi, rng, ) return atoms_out, positions, center_sym # ------------------------------------------------------------------ # # Generation # # ------------------------------------------------------------------ #
[docs] def stream(self) -> Iterator[Structure]: """Generate structures one by one, yielding each that passes all filters. Unlike :meth:`generate`, structures are yielded immediately as they pass, so callers can write output or stop early without waiting for all attempts to complete. Respects both *n_samples* (maximum attempts) and *n_success* (target number of passing structures): - If *n_success* is set, the iterator stops as soon as that many structures have been yielded — even if *n_samples* attempts have not been exhausted. - If *n_samples* is ``0`` (unlimited), the iterator runs until *n_success* structures have been yielded. - If *n_samples* attempts are exhausted before *n_success* is reached, a warning is emitted to *stderr* and the iterator ends. Each call creates a fresh :class:`random.Random` seeded with ``self.seed``, so repeated calls with the same seed are reproducible. Yields ------ Structure Each structure that passed all filters, in generation order. Examples -------- Write structures to a file as they are found:: gen = StructureGenerator( n_atoms=12, charge=0, mult=1, mode="gas", region="sphere:9", elements="1-30", n_success=10, n_samples=500, seed=42, ) for s in gen.stream(): s.write_xyz("out.xyz") """ rng = random.Random(self.seed) if self.verbose and self._filters: self._log( "[filter] " + ", ".join(f"{m} in [{lo:.4g},{hi:.4g}]" for m, lo, hi in self._filters) ) do_add_h = ("H" in self._element_pool) and self._add_hydrogen n_passed = n_invalid = n_attempted = n_rejected_filter = 0 unlimited = (self.n_samples == 0) denom = "∞" if unlimited else str(self.n_samples) width = len(denom) while True: # Stop conditions if not unlimited and n_attempted >= self.n_samples: break if self.n_success is not None and n_passed >= self.n_success: break i = n_attempted n_attempted += 1 atoms_list = self._sample_atoms(rng) if do_add_h: atoms_list = add_hydrogen(atoms_list, rng) ok, val_msg = validate_charge_mult(atoms_list, self.charge, self.mult) if not ok: 
n_invalid += 1 if self.verbose: self._log(f"[{i + 1:>{width}}/{denom}:invalid] {val_msg}") continue try: atoms_out, positions, center_sym = self._place_one(atoms_list, rng) except (RuntimeError, ValueError) as exc: if self.verbose: self._log(f"[ERROR] sample {i + 1}: {exc}") raise positions, converged = relax_positions( atoms_out, positions, self.cov_scale, self.relax_cycles, seed=self.seed ) if not converged and self.verbose: self._log( f"[{i + 1:>{width}}/{denom}:warn] " f"relax_positions did not converge in {self.relax_cycles} cycles." ) metrics = compute_all_metrics( atoms_out, positions, self.n_bins, self.w_atom, self.w_spatial, self._cutoff, self.cov_scale, ) passed = passes_filters(metrics, self._filters) if self.verbose: flag = "PASS" if passed else "skip" self._log( f"[{i + 1:>{width}}/{denom}:{flag}] " + " ".join(f"{k}={_fmt(v)}" for k, v in metrics.items()) ) if not passed: n_rejected_filter += 1 continue n_passed += 1 yield Structure( atoms=atoms_out, positions=positions, charge=self.charge, mult=self.mult, metrics=metrics, mode=self.mode, sample_index=n_passed, center_sym=center_sym if self.mode == "shell" else None, seed=self.seed, ) n_skip = n_attempted - n_passed - n_invalid if self.verbose: self._log( f"[summary] attempted={n_attempted} passed={n_passed} " f"rejected_parity={n_invalid} rejected_filter={n_skip}" ) # ── warnings.warn for noteworthy outcomes ───────────────────────── # Fires regardless of verbose so that downstream consumers # (ASE, HT pipelines) receive machine-visible signals even when # PASTED is not in verbose mode. # # Parity warnings fire only when n_passed == 0 (complete failure). # Partial parity rejection where some structures still passed is # expected behaviour for mixed-element pools and does not require # a warning — the verbose summary line already reports the counts. 
if n_invalid > 0 and n_passed == 0: warnings.warn( f"All {n_attempted} attempt(s) were rejected by the charge/" f"multiplicity parity check ({n_invalid} invalid). " f"No structures were generated. " f"Check that your element pool can satisfy " f"charge={self.charge}, mult={self.mult}.", UserWarning, stacklevel=4, ) if n_passed == 0 and n_invalid == 0: warnings.warn( f"No structures passed the metric filters after " f"{n_attempted} attempt(s) " f"({n_skip} rejected by filters). " f"Try relaxing the --filter thresholds or increasing n_samples.", UserWarning, stacklevel=4, ) elif ( self.n_success is not None and n_passed < self.n_success and not unlimited ): warnings.warn( f"Attempt budget exhausted ({n_attempted} attempts) before " f"reaching n_success={self.n_success}; " f"only {n_passed} structure(s) collected. " f"Increase n_samples or relax filters.", UserWarning, stacklevel=4, ) # Store run statistics so generate() can build a GenerationResult # without re-running the generator loop. self._last_run_stats: dict[str, int] = { "n_attempted": n_attempted, "n_passed": n_passed, "n_rejected_parity": n_invalid, "n_rejected_filter": n_rejected_filter, }
def generate(self) -> GenerationResult:
    """Run the generator to exhaustion and wrap the output in a
    :class:`GenerationResult`.

    Everything produced by :meth:`stream` is collected eagerly; the run
    statistics recorded by the streaming loop (attempt count plus the
    parity/filter rejection breakdown) are then attached, so automated
    pipelines can inspect the run without re-executing it.

    :class:`GenerationResult` implements the full ``list`` interface
    (indexing, iteration, ``len``, ``bool``), so existing code doing
    ``result[0]`` or ``for s in result`` keeps working unchanged.

    :func:`warnings.warn` (category :class:`UserWarning`) also fires when:

    - every attempt was rejected by the charge/multiplicity parity check,
    - no structure survived the metric filters, or
    - the attempt budget ran out before ``n_success`` was reached.

    Each call creates a fresh :class:`random.Random` seeded with
    ``self.seed``, so repeated calls with the same seed are reproducible.

    Returns
    -------
    GenerationResult
        The passing structures together with generation metadata.  Use
        ``result.structures`` for the raw list or ``result.summary()``
        for a one-line diagnostic string.

    Examples
    --------
    Drop-in list usage::

        for s in gen.generate():
            print(s.to_xyz())

    Metadata access::

        result = gen.generate()
        if result.n_rejected_parity > 0:
            print(result.summary())
    """
    collected = [*self.stream()]
    # stream() stores its counters in _last_run_stats on completion; fall
    # back to an empty mapping if the attribute was never set.
    run_stats: dict[str, int] = getattr(self, "_last_run_stats", {})
    fallback = len(collected)
    return GenerationResult(
        structures=collected,
        n_attempted=run_stats.get("n_attempted", fallback),
        n_passed=run_stats.get("n_passed", fallback),
        n_rejected_parity=run_stats.get("n_rejected_parity", 0),
        n_rejected_filter=run_stats.get("n_rejected_filter", 0),
        n_success_target=self.n_success,
    )
# ------------------------------------------------------------------ # # Iteration support # # ------------------------------------------------------------------ #
def __iter__(self) -> Iterator[Structure]:
    """Support ``for structure in generator`` by handing back the
    iterator produced by :meth:`stream`."""
    structures = self.stream()
    return structures
def __repr__(self) -> str:
    """Return a constructor-style one-line summary of the key settings."""
    fields = (
        f"n_atoms={self.n_atoms}",
        f"mode={self.mode!r}",
        f"charge={self.charge:+d}",
        f"mult={self.mult}",
        f"n_samples={self.n_samples}",
        f"n_success={self.n_success}",
        f"pool_size={len(self._element_pool)}",
    )
    return "StructureGenerator(" + ", ".join(fields) + ")"
# ---------------------------------------------------------------------------
# Functional API
# ---------------------------------------------------------------------------


def read_xyz(
    source: str | Path,
    *,
    recompute_metrics: bool = True,
    cutoff: float | None = None,
    n_bins: int = 20,
    w_atom: float = 0.5,
    w_spatial: float = 0.5,
    cov_scale: float = 1.0,
) -> list[Structure]:
    """Read one or more structures from an XYZ file or string.

    Convenience wrapper around :meth:`Structure.from_xyz` that reads
    **all frames** from a (possibly multi-frame) XYZ source and returns
    them as a list.  Both plain XYZ and PASTED extended XYZ are supported.

    Parameters
    ----------
    source:
        Path to an XYZ file **or** a raw XYZ string.  A string containing
        no newline is interpreted as a path (raw XYZ payloads always span
        multiple lines: atom count, comment, coordinates).
    recompute_metrics:
        Recompute all disorder metrics after loading each structure
        (default: ``True``).
    cutoff:
        Distance cutoff (Å) for metric computation.  Auto-computed from
        each structure's element pool when ``None``.
    n_bins:
        Histogram bins for ``H_spatial`` / ``RDF_dev`` (default: 20).
    w_atom:
        Weight of ``H_atom`` in ``H_total`` (default: 0.5).
    w_spatial:
        Weight of ``H_spatial`` in ``H_total`` (default: 0.5).
    cov_scale:
        Minimum distance scale factor used for metrics (default: 1.0).

    Returns
    -------
    list[Structure]
        One :class:`Structure` per frame, in file order.

    Raises
    ------
    FileNotFoundError
        If *source* denotes a path that does not exist.

    Examples
    --------
    Load a PASTED output file and pass the first structure to the
    optimizer::

        from pasted import read_xyz, StructureOptimizer

        structs = read_xyz("results.xyz")
        opt = StructureOptimizer(
            n_atoms=len(structs[0]),
            charge=structs[0].charge,
            mult=structs[0].mult,
            objective={"H_total": 1.0},
            elements=list(set(structs[0].atoms)),
            max_steps=3000,
            seed=42,
        )
        result = opt.run(initial=structs[0])

    Compose with :class:`GenerationResult` via ``+``::

        from pasted import read_xyz, generate

        existing = generate(n_atoms=10, charge=0, mult=1, mode="gas",
                            region="sphere:9", elements="6,7,8",
                            n_samples=5, seed=0)
        loaded = read_xyz("previous_run.xyz")
        # loaded is a list[Structure]; wrap manually if needed:
        from pasted import GenerationResult
        all_structs = existing + GenerationResult(structures=loaded,
                                                  n_passed=len(loaded),
                                                  n_attempted=len(loaded))
    """
    # A non-string (Path-like) input, or a newline-free string, names a
    # file on disk; everything else is raw XYZ text.
    treat_as_path = not isinstance(source, str) or "\n" not in source
    if treat_as_path:
        path = Path(source)
        if not path.exists():
            # Previously a missing file fell through to parse_xyz() on the
            # path string itself, producing a confusing parse failure or an
            # empty result.  Fail fast with the standard exception instead.
            raise FileNotFoundError(f"XYZ source not found: {path}")
        text = path.read_text()
    else:
        text = str(source)

    frames = parse_xyz(text)
    result: list[Structure] = []
    for atoms, positions, charge, mult, embedded_metrics in frames:
        if recompute_metrics:
            cut = cutoff
            if cut is None:
                # Auto cutoff: 1.5 × the median covalent-radius pair sum
                # over all (unordered, self-inclusive) element pairs in
                # this frame, scaled by cov_scale.
                radii = [_cov_radius_ang(a) for a in atoms]
                pair_sums = sorted(
                    ra + rb for i, ra in enumerate(radii) for rb in radii[i:]
                )
                median_sum = pair_sums[len(pair_sums) // 2]
                cut = cov_scale * 1.5 * median_sum
            metrics = compute_all_metrics(
                atoms, positions, n_bins, w_atom, w_spatial, cut, cov_scale
            )
        else:
            metrics = embedded_metrics
        result.append(
            Structure(
                atoms=list(atoms),
                positions=list(positions),
                charge=charge,
                mult=mult,
                metrics=metrics,
                mode="loaded_xyz",
            )
        )
    return result
def generate(
    *,
    n_atoms: int,
    charge: int,
    mult: int,
    mode: str = "gas",
    region: str | None = None,
    branch_prob: float = 0.3,
    chain_persist: float = 0.5,
    chain_bias: float = 0.0,
    bond_range: tuple[float, float] = (1.2, 1.6),
    center_z: int | None = None,
    coord_range: tuple[int, int] = (4, 8),
    shell_radius: tuple[float, float] = (1.8, 2.5),
    elements: str | list[str] | None = None,
    element_fractions: dict[str, float] | None = None,
    element_min_counts: dict[str, int] | None = None,
    element_max_counts: dict[str, int] | None = None,
    cov_scale: float = 1.0,
    relax_cycles: int = 1500,
    maxent_steps: int = 300,
    maxent_lr: float = 0.05,
    maxent_cutoff_scale: float = 2.5,
    trust_radius: float = 0.5,
    convergence_tol: float = 1e-3,
    add_hydrogen: bool = True,
    n_samples: int = 1,
    n_success: int | None = None,
    seed: int | None = None,
    n_bins: int = 20,
    w_atom: float = 0.5,
    w_spatial: float = 0.5,
    cutoff: float | None = None,
    filters: list[str] | None = None,
    verbose: bool = False,
) -> GenerationResult:
    """One-shot functional API: build a :class:`StructureGenerator` and
    return its :meth:`~StructureGenerator.generate` result.

    Every keyword is forwarded unchanged; consult
    :class:`StructureGenerator` for the meaning of each parameter.

    Returns
    -------
    GenerationResult
        A list-compatible object containing the structures that passed
        all filters plus metadata about the generation run (attempt
        counts, rejection breakdowns).  Behaves identically to
        ``list[Structure]`` in all normal usage (indexing, iteration,
        ``len``, ``bool``).

    :class:`UserWarning` is raised whenever:

    - attempts are rejected by the charge/multiplicity parity check,
    - no structures pass the metric filters, or
    - the attempt budget is exhausted before ``n_success`` is reached.

    Examples
    --------
    Drop-in list usage::

        from pasted import generate

        # 20 random gas-phase structures drawn from C/N/O
        structures = generate(
            n_atoms=10, charge=0, mult=1,
            mode="gas", region="sphere:8",
            elements="6,7,8",
            n_samples=20, seed=0,
        )
        for i, s in enumerate(structures):
            s.write_xyz("out.xyz", append=(i > 0))

    Inspecting rejection metadata::

        result = generate(
            n_atoms=10, charge=0, mult=1,
            mode="gas", region="sphere:8", elements="6,7,8",
            n_samples=50, seed=0,
            filters=["H_total:1.5:-"],
        )
        print(result.summary())
        # e.g. "passed=3 attempted=50 rejected_parity=0 rejected_filter=47"
    """
    # Gather the forwarded settings in one mapping, then splat it into the
    # class API so the pass-through stays visibly complete.
    settings = dict(
        n_atoms=n_atoms,
        charge=charge,
        mult=mult,
        mode=mode,
        region=region,
        branch_prob=branch_prob,
        chain_persist=chain_persist,
        chain_bias=chain_bias,
        bond_range=bond_range,
        center_z=center_z,
        coord_range=coord_range,
        shell_radius=shell_radius,
        elements=elements,
        element_fractions=element_fractions,
        element_min_counts=element_min_counts,
        element_max_counts=element_max_counts,
        cov_scale=cov_scale,
        relax_cycles=relax_cycles,
        maxent_steps=maxent_steps,
        maxent_lr=maxent_lr,
        maxent_cutoff_scale=maxent_cutoff_scale,
        trust_radius=trust_radius,
        convergence_tol=convergence_tol,
        add_hydrogen=add_hydrogen,
        n_samples=n_samples,
        n_success=n_success,
        seed=seed,
        n_bins=n_bins,
        w_atom=w_atom,
        w_spatial=w_spatial,
        cutoff=cutoff,
        filters=filters,
        verbose=verbose,
    )
    return StructureGenerator(**settings).generate()