"""
pasted._generator
=================
High-level API:
- :class:`Structure` — dataclass holding one generated structure.
- :class:`StructureGenerator` — stateful generator (class API).
- :func:`generate` — convenience functional wrapper.
"""
from __future__ import annotations
import random
import sys
import warnings
from collections import Counter
from collections.abc import Iterator
from dataclasses import dataclass, field
from pathlib import Path
from ._atoms import (
_Z_TO_SYM,
ATOMIC_NUMBERS,
_cov_radius_ang,
default_element_pool,
parse_element_spec,
parse_filter,
validate_charge_mult,
)
from ._io import _fmt, format_xyz, parse_xyz
from ._metrics import compute_all_metrics, passes_filters
from ._placement import (
Vec3,
add_hydrogen,
place_chain,
place_gas,
place_maxent,
place_shell,
relax_positions,
)
# ---------------------------------------------------------------------------
# Structure dataclass
# ---------------------------------------------------------------------------
@dataclass
class Structure:
    """A single generated atomic structure with its computed disorder metrics.

    Attributes
    ----------
    atoms:
        Element symbols, one per atom.
    positions:
        Cartesian coordinates in Å, one ``(x, y, z)`` tuple per atom.
    charge:
        Total system charge.
    mult:
        Spin multiplicity 2S+1.
    metrics:
        Computed disorder metrics (see :data:`pasted._atoms.ALL_METRICS`).
    mode:
        Placement mode used (``"gas"``, ``"chain"``, or ``"shell"``).
    sample_index:
        1-based index within the batch of structures that passed filters.
    center_sym:
        Element symbol of the shell center atom (shell mode only).
    seed:
        Random seed used for generation (``None`` if unseeded).
    """

    atoms: list[str]
    positions: list[Vec3]
    charge: int
    mult: int
    metrics: dict[str, float]
    mode: str
    sample_index: int = 0
    center_sym: str | None = None
    seed: int | None = None

    # ------------------------------------------------------------------ #
    # XYZ output                                                         #
    # ------------------------------------------------------------------ #
    def to_xyz(self, prefix: str = "") -> str:
        """Serialise to extended XYZ format.

        Parameters
        ----------
        prefix:
            Custom prefix for the comment line. When omitted the standard
            ``"sample=N mode=M …"`` string is generated automatically.

        Returns
        -------
        Multi-line string (no trailing newline).
        """
        if not prefix:
            # Auto-generated prefix; center/seed tokens are part of it only
            # when applicable.
            prefix = f"sample={self.sample_index} mode={self.mode}"
            if self.mode == "shell" and self.center_sym:
                prefix += f" center={self.center_sym}(Z={ATOMIC_NUMBERS[self.center_sym]})"
            if self.seed is not None:
                prefix += f" seed={self.seed}"
        return format_xyz(
            self.atoms,
            self.positions,
            self.charge,
            self.mult,
            self.metrics,
            prefix,
        )

    def write_xyz(self, path: str | Path, *, append: bool = True) -> None:
        """Write this structure to an XYZ file.

        Parameters
        ----------
        path:
            Output file path.
        append:
            If ``True`` (default) the file is opened in append mode so that
            multiple structures can be written in sequence. Use
            ``append=False`` to overwrite.
        """
        mode = "a" if append else "w"
        with Path(path).open(mode) as fh:
            fh.write(self.to_xyz() + "\n")

    # ------------------------------------------------------------------ #
    # Dunder helpers                                                     #
    # ------------------------------------------------------------------ #
    def __len__(self) -> int:
        return len(self.atoms)

    # ------------------------------------------------------------------ #
    # XYZ import                                                         #
    # ------------------------------------------------------------------ #
    @classmethod
    def from_xyz(
        cls,
        source: str | Path,
        *,
        frame: int = 0,
        recompute_metrics: bool = True,
        cutoff: float | None = None,
        n_bins: int = 20,
        w_atom: float = 0.5,
        w_spatial: float = 0.5,
        cov_scale: float = 1.0,
    ) -> Structure:
        """Load a :class:`Structure` from an XYZ file or string.

        Supports both plain XYZ and PASTED extended XYZ (with ``charge=``,
        ``mult=``, and metric tokens on the comment line). When
        *recompute_metrics* is ``True`` (default), all disorder metrics are
        recomputed from the loaded geometry so that the returned structure
        is fully usable as optimizer input or for filtering.

        Parameters
        ----------
        source:
            Path to an XYZ file **or** a raw XYZ string.
        frame:
            Zero-based frame index when *source* contains multiple
            concatenated structures (default: first frame).
        recompute_metrics:
            Recompute all disorder metrics after loading. Set to ``False``
            to skip the recomputation and return the structure with whatever
            metric values were embedded in the extended XYZ comment (or an
            empty dict for plain XYZ).
        cutoff:
            Distance cutoff (Å) for metric computation. Auto-computed from
            the element pool when ``None``.
        n_bins:
            Histogram bins for ``H_spatial`` / ``RDF_dev`` (default: 20).
        w_atom:
            Weight of ``H_atom`` in ``H_total`` (default: 0.5).
        w_spatial:
            Weight of ``H_spatial`` in ``H_total`` (default: 0.5).
        cov_scale:
            Minimum distance scale factor used for metrics (default: 1.0).

        Returns
        -------
        Structure

        Raises
        ------
        ValueError
            When the file / string cannot be parsed, or *frame* is out of
            range.
        """
        # Treat *source* as a path candidate unless it is a string that
        # already contains a newline (raw XYZ content spans several lines).
        p = Path(source) if not isinstance(source, str) or "\n" not in str(source) else None
        try:
            # Path.exists() can raise OSError/ValueError for strings that are
            # not valid file names (NUL bytes, over-long components); such
            # sources must fall back to being parsed as raw text.
            is_file = p is not None and p.exists()
        except (OSError, ValueError):
            is_file = False
        text = p.read_text() if is_file else str(source)
        frames = parse_xyz(text)
        if not frames:
            raise ValueError("No frames found in XYZ source.")
        if frame < 0 or frame >= len(frames):
            raise ValueError(
                f"frame={frame} out of range; source contains {len(frames)} frame(s)."
            )
        atoms, positions, charge, mult, embedded_metrics = frames[frame]
        if recompute_metrics:
            if cutoff is None:
                # Auto cutoff: cov_scale × 1.5 × median pairwise radius sum
                # over the loaded atoms (self-pairs included — mirrors
                # StructureGenerator._resolve_cutoff).
                radii = [_cov_radius_ang(a) for a in atoms]
                pair_sums = sorted(
                    ra + rb for i, ra in enumerate(radii) for rb in radii[i:]
                )
                median_sum = pair_sums[len(pair_sums) // 2]
                cutoff = cov_scale * 1.5 * median_sum
            metrics = compute_all_metrics(
                atoms, positions, n_bins, w_atom, w_spatial, cutoff, cov_scale
            )
        else:
            metrics = embedded_metrics
        return cls(
            atoms=list(atoms),
            positions=list(positions),
            charge=charge,
            mult=mult,
            metrics=metrics,
            mode="loaded_xyz",
        )

    def __repr__(self) -> str:
        counts = Counter(self.atoms)
        comp = "".join(f"{sym}{n}" if n > 1 else sym for sym, n in sorted(counts.items()))
        h_total = self.metrics.get("H_total", float("nan"))
        return f"Structure(n={len(self)}, comp={comp!r}, mode={self.mode!r}, H_total={h_total:.3f})"
# ---------------------------------------------------------------------------
# GenerationResult
# ---------------------------------------------------------------------------
@dataclass
class GenerationResult:
    """Return value of :func:`generate` and :meth:`StructureGenerator.generate`.

    Behaves like a ``list[Structure]`` in all normal usage (indexing,
    iteration, ``len``, boolean test, ``for s in result``) while also
    carrying metadata about how many attempts were made and why samples
    were rejected. This metadata is especially useful when integrating
    PASTED into automated pipelines such as ASE or high-throughput
    workflows, where a silent empty list would be indistinguishable from
    a successful run that just produced no results.

    Attributes
    ----------
    structures:
        Structures that passed all filters.
    n_attempted:
        Total placement attempts made.
    n_passed:
        Number of structures that passed all filters (equals
        ``len(structures)`` unless the caller mutates the list).
    n_rejected_parity:
        Attempts rejected by the charge/multiplicity parity check.
    n_rejected_filter:
        Attempts rejected by user-supplied metric filters.
    n_success_target:
        The ``n_success`` value that was in effect during generation
        (``None`` when not set).

    Notes
    -----
    ``GenerationResult`` is a :func:`~dataclasses.dataclass`; downstream
    code should treat it as immutable. The ``structures`` field is a
    plain ``list`` and may be sorted or sliced freely.
    """

    structures: list[Structure] = field(default_factory=list)
    n_attempted: int = 0
    n_passed: int = 0
    n_rejected_parity: int = 0
    n_rejected_filter: int = 0
    n_success_target: int | None = None

    # ------------------------------------------------------------------ #
    # list-compatible interface                                          #
    # ------------------------------------------------------------------ #
    def __len__(self) -> int:
        return len(self.structures)

    def __iter__(self) -> Iterator[Structure]:
        return iter(self.structures)

    def __getitem__(self, index: int | slice) -> Structure | list[Structure]:
        # list.__getitem__ natively handles both int and slice indexing;
        # the previous int/slice branch returned the same expression twice.
        return self.structures[index]

    def __bool__(self) -> bool:
        return bool(self.structures)

    def __add__(self, other: GenerationResult) -> GenerationResult:
        """Merge two :class:`GenerationResult` objects into one.

        Combines structures and accumulates all counters so that batch
        workflows can collect results across multiple calls and treat them
        as a single result.

        Parameters
        ----------
        other:
            Another :class:`GenerationResult` to merge into this one.

        Returns
        -------
        GenerationResult
            New result containing all structures from both operands.
            ``n_success_target`` is taken from *self* when set, otherwise
            from *other*.
        """
        if not isinstance(other, GenerationResult):
            return NotImplemented
        return GenerationResult(
            structures=self.structures + other.structures,
            n_attempted=self.n_attempted + other.n_attempted,
            n_passed=self.n_passed + other.n_passed,
            n_rejected_parity=self.n_rejected_parity + other.n_rejected_parity,
            n_rejected_filter=self.n_rejected_filter + other.n_rejected_filter,
            n_success_target=self.n_success_target
            if self.n_success_target is not None
            else other.n_success_target,
        )

    def __repr__(self) -> str:
        return (
            f"GenerationResult("
            f"passed={self.n_passed}, "
            f"attempted={self.n_attempted}, "
            f"rejected_parity={self.n_rejected_parity}, "
            f"rejected_filter={self.n_rejected_filter})"
        )

    # ------------------------------------------------------------------ #
    # Metadata helpers                                                   #
    # ------------------------------------------------------------------ #
    def summary(self) -> str:
        """Return a human-readable one-line summary of the generation run.

        Returns
        -------
        str
            E.g. ``"passed=5 attempted=20 rejected_parity=2 rejected_filter=13"``.
        """
        parts = [
            f"passed={self.n_passed}",
            f"attempted={self.n_attempted}",
            f"rejected_parity={self.n_rejected_parity}",
            f"rejected_filter={self.n_rejected_filter}",
        ]
        if self.n_success_target is not None:
            parts.append(f"n_success_target={self.n_success_target}")
        return " ".join(parts)
# ---------------------------------------------------------------------------
# StructureGenerator
# ---------------------------------------------------------------------------
[docs]
class StructureGenerator:
"""Generate random atomic structures with disorder metrics.
All parameters use Python snake_case names that correspond 1-to-1 with
their CLI ``--flag`` counterparts.
Parameters
----------
n_atoms:
Number of atoms per structure (before optional H augmentation).
charge:
Total system charge (applied to every structure).
mult:
Spin multiplicity 2S+1.
mode:
Placement mode: ``"gas"`` (default), ``"chain"``, or ``"shell"``.
region:
[gas] Region spec: ``"sphere:R"`` | ``"box:L"`` | ``"box:LX,LY,LZ"``.
Required when *mode="gas"*.
branch_prob:
[chain] Branching probability (default: 0.3).
chain_persist:
[chain] Directional persistence ∈ [0, 1] (default: 0.5).
chain_bias:
[chain] Global-axis drift strength ∈ [0, 1] (default: 0.0).
The direction of the first bond becomes the bias axis; each
subsequent step is blended toward that axis before normalisation.
0.0 → no bias (backwards-compatible); higher values produce more
elongated structures with larger ``shape_aniso``.
bond_range:
[chain / shell tails] Bond-length range in Å (default: ``(1.2, 1.6)``).
center_z:
[shell] Atomic number of center atom. ``None`` → random per sample.
coord_range:
[shell] Coordination-number range (default: ``(4, 8)``).
shell_radius:
[shell] Shell-radius range in Å (default: ``(1.8, 2.5)``).
elements:
Element pool. A spec string such as ``"1-30"`` or ``"6,7,8"``, an
explicit list of element symbols, or ``None`` for all Z = 1–106.
element_fractions:
Relative sampling weights for elements in the pool, as a
``{symbol: weight}`` dict (e.g. ``{"C": 0.5, "N": 0.3, "O": 0.2}``).
Weights are *relative* — they are normalised internally and need not
sum to 1. Elements absent from the dict receive a weight of 1.0.
When ``None`` (default), every element in the pool is sampled with
equal probability.
element_min_counts:
Minimum number of atoms per element guaranteed in every generated
structure (e.g. ``{"C": 2, "N": 1}``). The required atoms are
placed first; remaining slots are filled by weighted random sampling.
``None`` (default) → no lower bounds. The sum of all minimum counts
must not exceed ``n_atoms``.
element_max_counts:
Maximum number of atoms allowed per element
(e.g. ``{"N": 5, "O": 3}``). Elements that have reached their
cap are excluded from sampling for the remaining slots.
``None`` (default) → no upper bounds.
.. note::
When both *element_min_counts* and *element_max_counts* are
given, each element's min must be ≤ its max.
.. note::
The automatic hydrogen augmentation step (``add_hydrogen=True``)
runs *after* the constrained sampling and may temporarily exceed
*element_max_counts* for H. Set ``add_hydrogen=False`` if H
count limits are critical.
cov_scale:
Minimum-distance scale factor: ``d_min(i,j) = cov_scale × (r_i + r_j)``
using Pyykkö (2009) single-bond covalent radii. Default: ``1.0``.
relax_cycles:
Maximum repulsion-relaxation iterations (default: 1500).
add_hydrogen:
Automatically append H atoms when H is in the pool but the sampled
composition contains none (default: ``True``).
n_samples:
Maximum number of placement attempts (default: 1).
Use ``0`` to allow unlimited attempts (only valid when *n_success*
is also set, otherwise a :exc:`ValueError` is raised).
n_success:
Target number of structures that must pass all filters before
generation stops (default: ``None``).
- ``None`` → generate exactly *n_samples* attempts and return all
that passed (original behaviour).
- ``N > 0`` with ``n_samples > 0`` → stop as soon as *N* structures
pass **or** *n_samples* attempts are exhausted, whichever comes
first. Returns the structures collected so far with a warning if
fewer than *N* were found.
- ``N > 0`` with ``n_samples = 0`` → unlimited attempts; stop only
when *N* structures have passed.
seed:
Random seed for reproducibility (``None`` → non-deterministic).
n_bins:
Histogram bins for ``H_spatial`` and ``RDF_dev`` (default: 20).
w_atom:
Weight of ``H_atom`` in ``H_total`` (default: 0.5).
w_spatial:
Weight of ``H_spatial`` in ``H_total`` (default: 0.5).
cutoff:
Distance cutoff in Å for Steinhardt and graph metrics.
``None`` → auto-computed as ``cov_scale × 1.5 × median(r_i + r_j)``
over the element pool.
filters:
Filter strings of the form ``"METRIC:MIN:MAX"`` (use ``"-"`` for an
open bound). Only structures satisfying *all* filters are returned.
verbose:
Print progress and statistics to *stderr* (default: ``False``).
The CLI always passes ``True``; library callers usually leave it off.
Examples
--------
Class API::
from pasted import StructureGenerator
gen = StructureGenerator(
n_atoms=12, charge=0, mult=1,
mode="gas", region="sphere:9",
elements="1-30", n_samples=50, seed=42,
filters=["H_total:2.0:-"],
)
structures = gen.generate()
for s in structures:
print(s)
Functional API::
from pasted import generate
structures = generate(
n_atoms=12, charge=0, mult=1,
mode="chain", elements="6,7,8",
n_samples=20, seed=0,
)
"""
def __init__(
self,
*,
n_atoms: int,
charge: int,
mult: int,
mode: str = "gas",
region: str | None = None,
branch_prob: float = 0.3,
chain_persist: float = 0.5,
chain_bias: float = 0.0,
bond_range: tuple[float, float] = (1.2, 1.6),
center_z: int | None = None,
coord_range: tuple[int, int] = (4, 8),
shell_radius: tuple[float, float] = (1.8, 2.5),
elements: str | list[str] | None = None,
element_fractions: dict[str, float] | None = None,
element_min_counts: dict[str, int] | None = None,
element_max_counts: dict[str, int] | None = None,
cov_scale: float = 1.0,
relax_cycles: int = 1500,
maxent_steps: int = 300,
maxent_lr: float = 0.05,
maxent_cutoff_scale: float = 2.5,
trust_radius: float = 0.5,
convergence_tol: float = 1e-3,
add_hydrogen: bool = True,
n_samples: int = 1,
n_success: int | None = None,
seed: int | None = None,
n_bins: int = 20,
w_atom: float = 0.5,
w_spatial: float = 0.5,
cutoff: float | None = None,
filters: list[str] | None = None,
verbose: bool = False,
) -> None:
if mode not in ("gas", "chain", "shell", "maxent"):
raise ValueError(
f"mode must be 'gas', 'chain', 'shell', or 'maxent'; got {mode!r}"
)
if mode in ("gas", "maxent") and region is None:
raise ValueError("region is required when mode='gas' or mode='maxent'")
self.n_atoms = n_atoms
self.charge = charge
self.mult = mult
self.mode = mode
self.region = region
self.branch_prob = branch_prob
self.chain_persist = chain_persist
self.chain_bias = chain_bias
self.bond_range = bond_range
self.center_z = center_z
self.coord_range = coord_range
self.shell_radius = shell_radius
self.cov_scale = cov_scale
self.relax_cycles = relax_cycles
self.maxent_steps = maxent_steps
self.maxent_lr = maxent_lr
self.maxent_cutoff_scale = maxent_cutoff_scale
self.trust_radius = trust_radius
self.convergence_tol = convergence_tol
self._add_hydrogen = add_hydrogen
self.n_samples = n_samples
self.n_success = n_success
self.seed = seed
self.n_bins = n_bins
self.w_atom = w_atom
self.w_spatial = w_spatial
self.verbose = verbose
# ── n_samples / n_success validation ────────────────────────────
if n_samples == 0 and n_success is None:
raise ValueError(
"n_samples=0 (unlimited) requires n_success to be set; "
"otherwise generation would run forever."
)
if n_success is not None and n_success < 1:
raise ValueError(f"n_success must be >= 1; got {n_success}.")
# ── Element pool ────────────────────────────────────────────────
if elements is None:
self._element_pool: list[str] = default_element_pool()
elif isinstance(elements, str):
self._element_pool = parse_element_spec(elements)
else:
self._element_pool = list(elements)
# ── Element fractions ────────────────────────────────────────────
# Build a normalised weight list aligned with self._element_pool.
# Unknown keys in the dict raise ValueError immediately.
if element_fractions is not None:
unknown = set(element_fractions) - set(self._element_pool)
if unknown:
raise ValueError(
f"element_fractions contains symbols not in the element pool: "
f"{sorted(unknown)}"
)
weights = [float(element_fractions.get(sym, 1.0)) for sym in self._element_pool]
if any(w < 0 for w in weights):
raise ValueError("element_fractions weights must be non-negative.")
total = sum(weights)
if total == 0:
raise ValueError("element_fractions weights must not all be zero.")
self._element_weights: list[float] = [w / total for w in weights]
else:
n = len(self._element_pool)
self._element_weights = [1.0 / n] * n
# ── Element min/max counts ───────────────────────────────────────
# Validate and store; actual enforcement happens in stream().
if element_min_counts is not None:
unknown_min = set(element_min_counts) - set(self._element_pool)
if unknown_min:
raise ValueError(
f"element_min_counts contains symbols not in the element pool: "
f"{sorted(unknown_min)}"
)
if any(v < 0 for v in element_min_counts.values()):
raise ValueError("element_min_counts values must be non-negative.")
total_min = sum(element_min_counts.values())
if total_min > n_atoms:
raise ValueError(
f"Sum of element_min_counts ({total_min}) exceeds n_atoms ({n_atoms})."
)
if element_max_counts is not None:
unknown_max = set(element_max_counts) - set(self._element_pool)
if unknown_max:
raise ValueError(
f"element_max_counts contains symbols not in the element pool: "
f"{sorted(unknown_max)}"
)
if any(v < 0 for v in element_max_counts.values()):
raise ValueError("element_max_counts values must be non-negative.")
if element_min_counts is not None and element_max_counts is not None:
for sym in element_min_counts:
lo = element_min_counts[sym]
hi = element_max_counts.get(sym, lo)
if lo > hi:
raise ValueError(
f"element_min_counts[{sym!r}]={lo} > "
f"element_max_counts[{sym!r}]={hi}."
)
self._element_min_counts: dict[str, int] = dict(element_min_counts or {})
self._element_max_counts: dict[str, int] = dict(element_max_counts or {})
# ── Filters ─────────────────────────────────────────────────────
self._filters: list[tuple[str, float, float]] = [parse_filter(f) for f in (filters or [])]
# ── Cutoff ──────────────────────────────────────────────────────
self._cutoff: float = self._resolve_cutoff(cutoff)
# ── Shell center ─────────────────────────────────────────────────
self._fixed_center_sym: str | None = None
if mode == "shell" and center_z is not None:
if center_z not in _Z_TO_SYM:
raise ValueError(f"center_z={center_z}: unknown atomic number.")
sym = _Z_TO_SYM[center_z]
if sym not in self._element_pool:
raise ValueError(f"center_z={center_z} ({sym}) is not in the element pool.")
self._fixed_center_sym = sym
if self.verbose:
self._log(f"[pool] {len(self._element_pool)} elements in pool")
if mode == "shell":
if self._fixed_center_sym:
self._log(
f"[shell] center fixed: {self._fixed_center_sym} "
f"(Z={ATOMIC_NUMBERS[self._fixed_center_sym]})"
)
else:
self._log("[shell] center: random per sample (chaos mode)")
# ------------------------------------------------------------------ #
# Internal helpers #
# ------------------------------------------------------------------ #
def _log(self, msg: str) -> None:
"""Print *msg* to stderr when verbose mode is active."""
print(msg, file=sys.stderr)
def _resolve_cutoff(self, override: float | None) -> float:
if override is not None:
if self.verbose:
self._log(f"[cutoff] {override:.3f} Å (user-specified)")
return override
radii = [_cov_radius_ang(s) for s in self._element_pool]
pair_sums = sorted(ra + rb for i, ra in enumerate(radii) for rb in radii[i:])
median_sum = pair_sums[len(pair_sums) // 2]
cutoff = self.cov_scale * 1.5 * median_sum
if self.verbose:
self._log(
f"[cutoff] {cutoff:.3f} Å (auto: cov_scale={self.cov_scale} × 1.5 × "
f"median(r_i+r_j)={median_sum:.3f} Å)"
)
return cutoff
def _sample_atoms(self, rng: random.Random) -> list[str]:
"""Sample *n_atoms* element symbols respecting fractions and count bounds.
Algorithm
---------
1. If no fractions/min/max are configured, falls back to the
original uniform ``rng.choice`` per atom (preserves seed parity).
2. Otherwise: place the guaranteed minimum-count atoms first
(``element_min_counts``), fill remaining slots by weighted random
sampling (``element_fractions``), excluding elements that have
reached their ``element_max_counts`` cap, then shuffle.
Raises
------
RuntimeError
When the constraints cannot be satisfied (e.g. all remaining
elements are capped and there are still slots to fill).
"""
pool = self._element_pool
min_c = self._element_min_counts
max_c = self._element_max_counts
n = len(pool)
uniform = (n > 0 and all(abs(w - 1.0 / n) < 1e-12 for w in self._element_weights))
# Fast path: uniform weights, no bounds → identical to original behaviour
if uniform and not min_c and not max_c:
return [rng.choice(pool) for _ in range(self.n_atoms)]
weights = self._element_weights
# ── Step 1: fill guaranteed minimum counts ──────────────────────
counts: dict[str, int] = {sym: min_c.get(sym, 0) for sym in pool}
atoms: list[str] = []
for sym in pool:
atoms.extend([sym] * counts[sym])
remaining = self.n_atoms - len(atoms)
# ── Step 2: weighted sampling for remaining slots ────────────────
for _ in range(remaining):
# Build eligible pool (not yet capped)
eligible: list[str] = []
eligible_w: list[float] = []
for sym, w in zip(pool, weights, strict=True):
cap = max_c.get(sym, None)
if cap is None or counts.get(sym, 0) < cap:
eligible.append(sym)
eligible_w.append(w)
if not eligible:
raise RuntimeError(
"element_max_counts constraints cannot be satisfied: "
"all elements are capped before n_atoms is reached."
)
# Normalise eligible weights and do a weighted choice
total_w = sum(eligible_w)
cum: list[float] = []
acc = 0.0
for w in eligible_w:
acc += w / total_w
cum.append(acc)
r = rng.random()
chosen = eligible[-1]
for sym, c in zip(eligible, cum, strict=True):
if r <= c:
chosen = sym
break
counts[chosen] = counts.get(chosen, 0) + 1
atoms.append(chosen)
# ── Step 3: shuffle so forced atoms don't cluster at front ───────
rng.shuffle(atoms)
return atoms
# ------------------------------------------------------------------ #
# Public properties #
# ------------------------------------------------------------------ #
@property
def element_pool(self) -> list[str]:
"""A copy of the resolved element pool (list of symbols)."""
return list(self._element_pool)
@property
def cutoff(self) -> float:
    """Distance cutoff in Å used for Steinhardt and graph metrics.

    Resolved once in ``__init__`` (user override or auto-computed from
    the element pool); read-only thereafter.
    """
    return self._cutoff
# ------------------------------------------------------------------ #
# Internal placement dispatch                                        #
# ------------------------------------------------------------------ #
def _place_one(
    self,
    atoms_list: list[str],
    rng: random.Random,
) -> tuple[list[str], list[Vec3], str | None]:
    """Dispatch to the mode-specific placement routine.

    Returns ``(atoms, positions, center_sym)``; *center_sym* is only
    meaningful in shell mode and ``None`` otherwise.

    Raises
    ------
    RuntimeError, ValueError
        Propagated from the underlying placement functions.
    """
    bond_lo, bond_hi = self.bond_range
    if self.mode == "gas":
        assert self.region is not None  # guaranteed by __init__ validation
        gas_atoms, gas_pos = place_gas(
            atoms_list,
            self.region,
            rng,
        )
        return gas_atoms, gas_pos, None
    if self.mode == "chain":
        chain_atoms, chain_pos = place_chain(
            atoms_list,
            bond_lo,
            bond_hi,
            self.branch_prob,
            self.chain_persist,
            rng,
            chain_bias=self.chain_bias,
        )
        return chain_atoms, chain_pos, None
    if self.mode == "maxent":
        assert self.region is not None
        me_atoms, me_pos = place_maxent(
            atoms_list,
            self.region,
            self.cov_scale,
            rng,
            maxent_steps=self.maxent_steps,
            maxent_lr=self.maxent_lr,
            maxent_cutoff_scale=self.maxent_cutoff_scale,
            trust_radius=self.trust_radius,
            convergence_tol=self.convergence_tol,
            seed=self.seed,
        )
        return me_atoms, me_pos, None
    # Shell — the only mode left after __init__'s mode validation.
    shell_lo, shell_hi = self.shell_radius
    coord_lo, coord_hi = self.coord_range
    center = (
        self._fixed_center_sym
        if self._fixed_center_sym is not None
        else rng.choice(atoms_list)
    )
    shell_atoms, shell_pos = place_shell(
        atoms_list,
        center,
        coord_lo,
        coord_hi,
        shell_lo,
        shell_hi,
        bond_lo,
        bond_hi,
        rng,
    )
    return shell_atoms, shell_pos, center
# ------------------------------------------------------------------ #
# Generation #
# ------------------------------------------------------------------ #
[docs]
def stream(self) -> Iterator[Structure]:
    """Generate structures one by one, yielding each that passes all filters.

    Unlike :meth:`generate`, structures are yielded immediately as they
    pass, so callers can write output or stop early without waiting for
    all attempts to complete.

    Respects both *n_samples* (maximum attempts) and *n_success* (target
    number of passing structures):

    - If *n_success* is set, the iterator stops as soon as that many
      structures have been yielded — even if *n_samples* attempts have
      not been exhausted.
    - If *n_samples* is ``0`` (unlimited), the iterator runs until
      *n_success* structures have been yielded.
    - If *n_samples* attempts are exhausted before *n_success* is
      reached, a warning is emitted and the iterator ends.

    Each call creates a fresh :class:`random.Random` seeded with
    ``self.seed``, so repeated calls with the same seed are reproducible.
    After exhaustion, run statistics are stored in ``_last_run_stats``
    for :meth:`generate`; an abandoned (partially consumed) iterator
    does not update them.

    Yields
    ------
    Structure
        Each structure that passed all filters, in generation order.
    """
    # Fresh RNG per call → deterministic streams for a fixed seed.
    rng = random.Random(self.seed)
    if self.verbose and self._filters:
        self._log(
            "[filter] "
            + ", ".join(f"{m} in [{lo:.4g},{hi:.4g}]" for m, lo, hi in self._filters)
        )
    # H augmentation only applies when H is actually in the pool.
    do_add_h = ("H" in self._element_pool) and self._add_hydrogen
    n_passed = n_invalid = n_attempted = n_rejected_filter = 0
    unlimited = (self.n_samples == 0)
    # Progress denominator for log lines ("∞" in unlimited mode).
    denom = "∞" if unlimited else str(self.n_samples)
    width = len(denom)
    while True:
        # Stop conditions
        if not unlimited and n_attempted >= self.n_samples:
            break
        if self.n_success is not None and n_passed >= self.n_success:
            break
        i = n_attempted
        n_attempted += 1
        atoms_list = self._sample_atoms(rng)
        if do_add_h:
            atoms_list = add_hydrogen(atoms_list, rng)
        # Parity check: composition must be compatible with charge/mult.
        ok, val_msg = validate_charge_mult(atoms_list, self.charge, self.mult)
        if not ok:
            n_invalid += 1
            if self.verbose:
                self._log(f"[{i + 1:>{width}}/{denom}:invalid] {val_msg}")
            continue
        try:
            atoms_out, positions, center_sym = self._place_one(atoms_list, rng)
        except (RuntimeError, ValueError) as exc:
            # Placement failures are fatal: log (verbose) and re-raise.
            if self.verbose:
                self._log(f"[ERROR] sample {i + 1}: {exc}")
            raise
        positions, converged = relax_positions(
            atoms_out, positions, self.cov_scale, self.relax_cycles, seed=self.seed
        )
        if not converged and self.verbose:
            self._log(
                f"[{i + 1:>{width}}/{denom}:warn] "
                f"relax_positions did not converge in {self.relax_cycles} cycles."
            )
        metrics = compute_all_metrics(
            atoms_out,
            positions,
            self.n_bins,
            self.w_atom,
            self.w_spatial,
            self._cutoff,
            self.cov_scale,
        )
        passed = passes_filters(metrics, self._filters)
        if self.verbose:
            flag = "PASS" if passed else "skip"
            self._log(
                f"[{i + 1:>{width}}/{denom}:{flag}] "
                + " ".join(f"{k}={_fmt(v)}" for k, v in metrics.items())
            )
        if not passed:
            n_rejected_filter += 1
            continue
        n_passed += 1
        # sample_index is 1-based over *passing* structures only.
        yield Structure(
            atoms=atoms_out,
            positions=positions,
            charge=self.charge,
            mult=self.mult,
            metrics=metrics,
            mode=self.mode,
            sample_index=n_passed,
            center_sym=center_sym if self.mode == "shell" else None,
            seed=self.seed,
        )
    # Every attempt either passed, failed parity, or failed filters, so
    # n_skip equals n_rejected_filter here.
    n_skip = n_attempted - n_passed - n_invalid
    if self.verbose:
        self._log(
            f"[summary] attempted={n_attempted} passed={n_passed} "
            f"rejected_parity={n_invalid} rejected_filter={n_skip}"
        )
    # ── warnings.warn for noteworthy outcomes ─────────────────────────
    # Fires regardless of verbose so that downstream consumers
    # (ASE, HT pipelines) receive machine-visible signals even when
    # PASTED is not in verbose mode.
    #
    # Parity warnings fire only when n_passed == 0 (complete failure).
    # Partial parity rejection where some structures still passed is
    # expected behaviour for mixed-element pools and does not require
    # a warning — the verbose summary line already reports the counts.
    if n_invalid > 0 and n_passed == 0:
        # NOTE(review): when both parity AND filter rejections occurred
        # with zero passes, this message claims *all* attempts failed
        # parity, which can overstate the case — consider requiring
        # n_invalid == n_attempted here. Verify before changing.
        warnings.warn(
            f"All {n_attempted} attempt(s) were rejected by the charge/"
            f"multiplicity parity check ({n_invalid} invalid). "
            f"No structures were generated. "
            f"Check that your element pool can satisfy "
            f"charge={self.charge}, mult={self.mult}.",
            UserWarning,
            stacklevel=4,
        )
    if n_passed == 0 and n_invalid == 0:
        warnings.warn(
            f"No structures passed the metric filters after "
            f"{n_attempted} attempt(s) "
            f"({n_skip} rejected by filters). "
            f"Try relaxing the --filter thresholds or increasing n_samples.",
            UserWarning,
            stacklevel=4,
        )
    elif (
        self.n_success is not None
        and n_passed < self.n_success
        and not unlimited
    ):
        warnings.warn(
            f"Attempt budget exhausted ({n_attempted} attempts) before "
            f"reaching n_success={self.n_success}; "
            f"only {n_passed} structure(s) collected. "
            f"Increase n_samples or relax filters.",
            UserWarning,
            stacklevel=4,
        )
    # Store run statistics so generate() can build a GenerationResult
    # without re-running the generator loop.
    self._last_run_stats: dict[str, int] = {
        "n_attempted": n_attempted,
        "n_passed": n_passed,
        "n_rejected_parity": n_invalid,
        "n_rejected_filter": n_rejected_filter,
    }
[docs]
def generate(self) -> GenerationResult:
    """Run the generator to completion and return a :class:`GenerationResult`.

    Drains :meth:`stream` into a list, then attaches the run statistics
    recorded by the streaming loop (attempt counts, parity/filter
    rejection breakdowns) so downstream pipelines can inspect the run
    without re-executing it.

    :class:`GenerationResult` supports the full ``list`` interface
    (indexing, iteration, ``len``, ``bool``), so existing code doing
    ``result[0]`` or ``for s in result`` keeps working unchanged.

    :func:`warnings.warn` (category :class:`UserWarning`) fires inside
    the streaming loop when attempts are rejected by the
    charge/multiplicity parity check, when no structure passes the
    metric filters, or when the attempt budget runs out before
    ``n_success`` is reached.

    Each call builds a fresh :class:`random.Random` seeded with
    ``self.seed``, so repeated calls with the same seed reproduce the
    same structures.

    Returns
    -------
    GenerationResult
        The passing structures plus generation metadata. Use
        ``result.structures`` for the raw list or ``result.summary()``
        for a one-line diagnostic string.

    Examples
    --------
    Drop-in list usage::

        result = gen.generate()
        for s in result:
            print(s.to_xyz())

    Metadata access::

        result = gen.generate()
        if result.n_rejected_parity > 0:
            print(result.summary())
    """
    collected = [s for s in self.stream()]
    # stream() stores its counters in _last_run_stats as a side effect;
    # fall back to sensible defaults if it never got the chance to.
    run_stats: dict[str, int] = getattr(self, "_last_run_stats", {})
    n_default = len(collected)
    return GenerationResult(
        structures=collected,
        n_attempted=run_stats.get("n_attempted", n_default),
        n_passed=run_stats.get("n_passed", n_default),
        n_rejected_parity=run_stats.get("n_rejected_parity", 0),
        n_rejected_filter=run_stats.get("n_rejected_filter", 0),
        n_success_target=self.n_success,
    )
# ------------------------------------------------------------------ #
# Iteration support #
# ------------------------------------------------------------------ #
[docs]
def __iter__(self) -> Iterator[Structure]:
    """Iterate over generated structures.

    Each iteration starts a fresh :meth:`stream` run; ``iter()`` on the
    resulting iterator is the identity, so this is pure delegation.
    """
    return iter(self.stream())
[docs]
def __repr__(self) -> str:
    """Return a constructor-style one-line summary of this generator."""
    # Assemble the field list piecewise; the joined output is identical
    # to the previous single f-string rendering.
    parts = [
        f"n_atoms={self.n_atoms}",
        f"mode={self.mode!r}",
        f"charge={self.charge:+d}",   # explicit sign, matching CLI echo style
        f"mult={self.mult}",
        f"n_samples={self.n_samples}",
        f"n_success={self.n_success}",
        f"pool_size={len(self._element_pool)}",
    ]
    return f"StructureGenerator({', '.join(parts)})"
# ---------------------------------------------------------------------------
# Functional API
# ---------------------------------------------------------------------------
def read_xyz(
    source: str | Path,
    *,
    recompute_metrics: bool = True,
    cutoff: float | None = None,
    n_bins: int = 20,
    w_atom: float = 0.5,
    w_spatial: float = 0.5,
    cov_scale: float = 1.0,
) -> list[Structure]:
    """Read one or more structures from an XYZ file or raw XYZ string.

    Convenience wrapper around the XYZ parser that loads **all frames**
    of a (possibly multi-frame) source and returns them in file order.
    Both plain XYZ and PASTED extended XYZ are accepted.

    Parameters
    ----------
    source:
        Path to an XYZ file **or** a raw XYZ string.
    recompute_metrics:
        Recompute all disorder metrics after loading each frame
        (default: ``True``); otherwise keep any metrics embedded in the
        file's comment lines.
    cutoff:
        Distance cutoff (Å) for metric computation. When ``None`` it is
        derived per frame from the covalent radii of its atoms.
    n_bins:
        Histogram bins for ``H_spatial`` / ``RDF_dev`` (default: 20).
    w_atom:
        Weight of ``H_atom`` in ``H_total`` (default: 0.5).
    w_spatial:
        Weight of ``H_spatial`` in ``H_total`` (default: 0.5).
    cov_scale:
        Minimum distance scale factor used for metrics (default: 1.0).

    Returns
    -------
    list[Structure]
        One :class:`Structure` per frame, tagged ``mode="loaded_xyz"``.

    Examples
    --------
    Load a PASTED output file and feed the first frame to the optimizer::

        from pasted import read_xyz, StructureOptimizer
        structs = read_xyz("results.xyz")
        opt = StructureOptimizer(
            n_atoms=len(structs[0]),
            charge=structs[0].charge,
            mult=structs[0].mult,
            objective={"H_total": 1.0},
            elements=list(set(structs[0].atoms)),
            max_steps=3000,
            seed=42,
        )
        result = opt.run(initial=structs[0])
    """
    # Decide whether *source* names a file or is literal XYZ text: a
    # string containing a newline can only be raw content; anything else
    # (Path objects, single-line strings) is tried as a path first.
    if isinstance(source, str) and "\n" in source:
        candidate: Path | None = None
    else:
        candidate = Path(source)
    if candidate is not None and candidate.exists():
        text = candidate.read_text()
    else:
        # Non-existent path strings fall through and are parsed as text,
        # mirroring the original best-effort behaviour.
        text = str(source)

    structures: list[Structure] = []
    for atoms, positions, charge, mult, embedded_metrics in parse_xyz(text):
        if not recompute_metrics:
            metrics = embedded_metrics
        else:
            cut = cutoff
            if cut is None:
                # Auto-cutoff heuristic: 1.5x the median covalent-radius
                # pair sum (self-pairs included), scaled by cov_scale.
                radii = [_cov_radius_ang(a) for a in atoms]
                pair_sums = sorted(
                    ra + rb for i, ra in enumerate(radii) for rb in radii[i:]
                )
                median_sum = pair_sums[len(pair_sums) // 2]
                cut = cov_scale * 1.5 * median_sum
            metrics = compute_all_metrics(
                atoms, positions, n_bins, w_atom, w_spatial, cut, cov_scale
            )
        structures.append(
            Structure(
                atoms=list(atoms),
                positions=list(positions),
                charge=charge,
                mult=mult,
                metrics=metrics,
                mode="loaded_xyz",
            )
        )
    return structures
[docs]
def generate(
    *,
    n_atoms: int,
    charge: int,
    mult: int,
    mode: str = "gas",
    region: str | None = None,
    branch_prob: float = 0.3,
    chain_persist: float = 0.5,
    chain_bias: float = 0.0,
    bond_range: tuple[float, float] = (1.2, 1.6),
    center_z: int | None = None,
    coord_range: tuple[int, int] = (4, 8),
    shell_radius: tuple[float, float] = (1.8, 2.5),
    elements: str | list[str] | None = None,
    element_fractions: dict[str, float] | None = None,
    element_min_counts: dict[str, int] | None = None,
    element_max_counts: dict[str, int] | None = None,
    cov_scale: float = 1.0,
    relax_cycles: int = 1500,
    maxent_steps: int = 300,
    maxent_lr: float = 0.05,
    maxent_cutoff_scale: float = 2.5,
    trust_radius: float = 0.5,
    convergence_tol: float = 1e-3,
    add_hydrogen: bool = True,
    n_samples: int = 1,
    n_success: int | None = None,
    seed: int | None = None,
    n_bins: int = 20,
    w_atom: float = 0.5,
    w_spatial: float = 0.5,
    cutoff: float | None = None,
    filters: list[str] | None = None,
    verbose: bool = False,
) -> GenerationResult:
    """Functional one-shot wrapper: build a :class:`StructureGenerator`
    and immediately call :meth:`~StructureGenerator.generate`.

    All keyword arguments are forwarded unchanged; see
    :class:`StructureGenerator` for full parameter documentation.

    Returns
    -------
    GenerationResult
        A list-compatible container of the structures that passed all
        filters, plus run metadata (attempt counts, rejection
        breakdowns). Behaves like ``list[Structure]`` for indexing,
        iteration, ``len`` and ``bool``.

    :class:`UserWarning` is raised whenever:

    - attempts are rejected by the charge/multiplicity parity check,
    - no structures pass the metric filters, or
    - the attempt budget is exhausted before ``n_success`` is reached.

    Examples
    --------
    Drop-in list usage::

        from pasted import generate
        # 20 random gas-phase structures drawn from C/N/O
        structures = generate(
            n_atoms=10, charge=0, mult=1,
            mode="gas", region="sphere:8",
            elements="6,7,8", n_samples=20, seed=0,
        )
        for i, s in enumerate(structures):
            s.write_xyz("out.xyz", append=(i > 0))

    Inspecting rejection metadata::

        result = generate(
            n_atoms=10, charge=0, mult=1,
            mode="gas", region="sphere:8",
            elements="6,7,8", n_samples=50, seed=0,
            filters=["H_total:1.5:-"],
        )
        print(result.summary())
        # e.g. "passed=3 attempted=50 rejected_parity=0 rejected_filter=47"
    """
    # At this point no other local has been bound, so locals() holds
    # exactly the keyword arguments above; forward them verbatim.  The
    # parameter names match StructureGenerator's keyword names one-to-one.
    params = dict(locals())
    return StructureGenerator(**params).generate()