MADNESS 0.10.1
py_functor.h
Go to the documentation of this file.
1/*
2 pymadness - Python bindings for MADNESS
3 Adapter to use Python callables as MADNESS FunctionFunctorInterface.
4*/
5
6#ifndef PYMADNESS_PY_FUNCTOR_H
7#define PYMADNESS_PY_FUNCTOR_H
8
9#include <pybind11/pybind11.h>
10#include <pybind11/numpy.h>
11#include <pybind11/complex.h>
13#include <atomic>
14#include <stdexcept>
15#include <string>
16#include <thread>
17
18namespace py = pybind11;
19
20/// Wraps a Python callable as a MADNESS FunctionFunctorInterface<T,NDIM>.
21///
22/// Supports two calling conventions:
23/// - Scalar: f(r) where r has shape (NDIM,), returns scalar
24/// - Vectorized: f(r) where r has shape (npts, NDIM), returns array of shape (npts,)
25///
26/// The vectorized path is always used (supports_vectorized() returns true).
27/// On the first batch call, we probe the callable to detect which convention
28/// it uses and cache the result.
29template<typename T, std::size_t NDIM>
31 py::object py_callable_;
32 mutable std::atomic<int> vectorized_mode_{0}; // -1=probing, 0=unknown, 1=returns array, 2=returns scalar
33
34public:
35 explicit PyFunctor(py::object f) : py_callable_(std::move(f)) {}
36
38 // py::object's destructor decrements Python's refcount, which requires
39 // the GIL. MADNESS may destroy the shared_ptr<PyFunctor> from a worker
40 // thread that does not hold the GIL, so we must acquire it here.
41 py::gil_scoped_acquire gil;
42 py_callable_ = py::none();
43 }
44
45 T operator()(const madness::Vector<double, NDIM>& r) const override {
46 py::gil_scoped_acquire gil;
47
48 // Build a numpy array from the coordinate vector
49 py::array_t<double> arr(NDIM);
50 auto buf = arr.mutable_unchecked<1>();
51 for (std::size_t i = 0; i < NDIM; ++i) {
52 buf(i) = r[i];
53 }
54
55 py::object result = py_callable_(arr);
56 return result.cast<T>();
57 }
58
59 bool supports_vectorized() const override { return true; }
60
61private:
62 /// Core vectorized evaluation: builds numpy array from coordinate pointers,
63 /// calls the Python callable once, and fills fvals.
64 void eval_vectorized(const madness::Vector<double*, NDIM>& xvals, T* fvals, int npts) const {
65 py::gil_scoped_acquire gil;
66
67 // Build a numpy array of shape (npts, NDIM) from the coordinate arrays
68 py::array_t<double> coords({static_cast<py::ssize_t>(npts),
69 static_cast<py::ssize_t>(NDIM)});
70 auto cbuf = coords.mutable_unchecked<2>();
71 for (int i = 0; i < npts; ++i) {
72 for (std::size_t d = 0; d < NDIM; ++d) {
73 cbuf(i, d) = xvals[d][i];
74 }
75 }
76
77 // Probe on first call to determine if callable returns array or scalar.
78 // Note: eval_vectorized always holds the GIL (acquired above), so only one
79 // thread can execute this block at a time. The atomic CAS still ensures
80 // correct memory ordering across threads.
81 int expected = 0;
82 if (vectorized_mode_.compare_exchange_strong(expected, -1,
83 std::memory_order_acq_rel, std::memory_order_acquire)) {
84 // We won the probe; expected was 0, now set to -1 (in-progress sentinel)
85 int result_mode = 2; // default: fall back to per-point scalar loop
86 try {
87 py::object result = py_callable_(coords);
88 py::array_t<T> arr = result.cast<py::array_t<T>>();
89 if (arr.ndim() == 1 && arr.shape(0) == npts) {
90 result_mode = 1;
91 auto rbuf = arr.template unchecked<1>();
92 for (int i = 0; i < npts; ++i) {
93 fvals[i] = rbuf(i);
94 }
95 }
96 } catch (...) {
97 // Call failed or didn't return a proper array — fall through
98 // Clear any pending Python exception
99 PyErr_Clear();
100 }
101 vectorized_mode_.store(result_mode, std::memory_order_release);
102 if (result_mode == 2) {
103 eval_scalar_loop(coords, fvals, npts);
104 }
105 return;
106 }
107
108 // If another thread is probing (-1), we must release the GIL while
109 // waiting. The probing thread needs the GIL to finish its Python call,
110 // so spinning here with the GIL held would deadlock.
111 int mode = expected; // CAS failure: expected holds the current value
112 while (mode == -1) {
113 {
114 py::gil_scoped_release release;
115 std::this_thread::yield();
116 }
117 mode = vectorized_mode_.load(std::memory_order_acquire);
118 }
119
120 if (mode == 1) {
121 py::object result = py_callable_(coords);
122 py::array_t<T> arr = result.cast<py::array_t<T>>();
123 if (arr.ndim() != 1 || arr.shape(0) != npts) {
124 throw std::runtime_error(
125 "PyFunctor: vectorized callable returned array with wrong shape; "
126 "expected 1D array of length " + std::to_string(npts));
127 }
128 auto rbuf = arr.template unchecked<1>();
129 for (int i = 0; i < npts; ++i) {
130 fvals[i] = rbuf(i);
131 }
132 } else {
133 eval_scalar_loop(coords, fvals, npts);
134 }
135 }
136
137 /// Evaluate the callable per-point using scalar convention (still 1 GIL acquisition).
138 void eval_scalar_loop(const py::array_t<double>& coords, T* fvals, int npts) const {
139 auto cbuf = coords.unchecked<2>();
140 for (int i = 0; i < npts; ++i) {
141 py::array_t<double> pt(NDIM);
142 auto pbuf = pt.mutable_unchecked<1>();
143 for (std::size_t d = 0; d < NDIM; ++d) {
144 pbuf(d) = cbuf(i, d);
145 }
146 py::object result = py_callable_(pt);
147 fvals[i] = result.cast<T>();
148 }
149 }
150
151public:
152 // Vectorized operator() overrides for NDIM = 1..6
153 // These match the signatures in FunctionFunctorInterface (function_interface.h:99-121).
154 // We provide all six and only the one matching our NDIM will be called.
155
156 void operator()(const madness::Vector<double*, 1>& xvals, T* fvals, int npts) const override {
157 if constexpr (NDIM == 1) { eval_vectorized(xvals, fvals, npts); }
158 }
159
160 void operator()(const madness::Vector<double*, 2>& xvals, T* fvals, int npts) const override {
161 if constexpr (NDIM == 2) { eval_vectorized(xvals, fvals, npts); }
162 }
163
164 void operator()(const madness::Vector<double*, 3>& xvals, T* fvals, int npts) const override {
165 if constexpr (NDIM == 3) { eval_vectorized(xvals, fvals, npts); }
166 }
167
168 void operator()(const madness::Vector<double*, 4>& xvals, T* fvals, int npts) const override {
169 if constexpr (NDIM == 4) { eval_vectorized(xvals, fvals, npts); }
170 }
171
172 void operator()(const madness::Vector<double*, 5>& xvals, T* fvals, int npts) const override {
173 if constexpr (NDIM == 5) { eval_vectorized(xvals, fvals, npts); }
174 }
175
176 void operator()(const madness::Vector<double*, 6>& xvals, T* fvals, int npts) const override {
177 if constexpr (NDIM == 6) { eval_vectorized(xvals, fvals, npts); }
178 }
179};
180
181#endif // PYMADNESS_PY_FUNCTOR_H
Definition py_functor.h:30
bool supports_vectorized() const override
Does the interface support a vectorized operator()?
Definition py_functor.h:59
void operator()(const madness::Vector< double *, 5 > &xvals, T *fvals, int npts) const override
Definition py_functor.h:172
void operator()(const madness::Vector< double *, 2 > &xvals, T *fvals, int npts) const override
Definition py_functor.h:160
void operator()(const madness::Vector< double *, 4 > &xvals, T *fvals, int npts) const override
Definition py_functor.h:168
void eval_vectorized(const madness::Vector< double *, NDIM > &xvals, T *fvals, int npts) const
Definition py_functor.h:64
py::object py_callable_
Definition py_functor.h:31
std::atomic< int > vectorized_mode_
Definition py_functor.h:32
void operator()(const madness::Vector< double *, 1 > &xvals, T *fvals, int npts) const override
Definition py_functor.h:156
void operator()(const madness::Vector< double *, 3 > &xvals, T *fvals, int npts) const override
Definition py_functor.h:164
void eval_scalar_loop(const py::array_t< double > &coords, T *fvals, int npts) const
Evaluate the callable per-point using scalar convention (still 1 GIL acquisition).
Definition py_functor.h:138
~PyFunctor()
Definition py_functor.h:37
void operator()(const madness::Vector< double *, 6 > &xvals, T *fvals, int npts) const override
Definition py_functor.h:176
PyFunctor(py::object f)
Definition py_functor.h:35
T operator()(const madness::Vector< double, NDIM > &r) const override
You should implement this to return f(x)
Definition py_functor.h:45
Abstract base class interface required for functors used as input to Functions.
Definition function_interface.h:68
A simple, fixed dimension vector.
Definition vector.h:64
double(* f)(const coord_3d &)
Definition derivatives.cc:54
auto T(World &world, response_space &f) -> response_space
Definition global_functions.cc:28
Definition mraimpl.h:51
static const double d
Definition nonlinschro.cc:121
constexpr std::size_t NDIM
Definition testgconv.cc:54
const auto npts
Definition testgconv.cc:52