Source code for attune._holistic

"""Function for processing multi-dependent tuning data."""

import itertools

import numpy as np
import scipy
import scipy.interpolate
import scipy.optimize
import scipy.spatial
import WrightTools as wt

from ._instrument import Instrument
from ._transition import Transition
from ._plot import plot_holistic
from ._common import save


# Public API of this module: only the ``holistic`` workup entry point is exported;
# the underscore-prefixed helpers are internal.
__all__ = ["holistic"]


def _holistic(data, amplitudes, centers, arrangement):
    """Find, for every setpoint of ``arrangement``, the motor positions of peak amplitude.

    Motor space (the axes of ``data``) is triangulated, and amplitudes and centers are
    linearly interpolated over that triangulation. For each independent setpoint, the
    points where the interpolated center crosses the setpoint are collected along simplex
    edges (an iso-contour in motor space). A Gaussian of amplitude is then fit along each
    motor coordinate of that contour; the fitted peak positions are the new motor values.

    Parameters
    ----------
    data: WrightTools.Data
        Data whose axes span motor space.
    amplitudes, centers: WrightTools.data.Channel
        Amplitude and center channels of ``data``. Centers may be NaN where the amplitude
        was rejected as noise.
    arrangement:
        Arrangement object; only its ``independent`` setpoints are read.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(arrangement.independent), len(data.axes))``. A row is all NaN when four
        or fewer contour points were found for that setpoint.
    """
    axes = data.axes
    ndim = len(axes)
    # One row per sample, one column per motor axis (axes broadcast to the channel shape).
    points = np.array([np.broadcast_to(a[:], amplitudes.shape).flatten() for a in axes]).T
    delaunay = scipy.spatial.Delaunay(points)

    # Both interpolators share the same triangulation, so it is only computed once.
    amp_interp = scipy.interpolate.LinearNDInterpolator(delaunay, amplitudes.points.flatten())
    cen_interp = scipy.interpolate.LinearNDInterpolator(delaunay, centers.points.flatten())

    out_points = []
    for setpoint in arrangement.independent:
        iso_points = np.array(
            [
                crossing
                for _, simplex_points, simplex_values in _find_simplices_containing(
                    delaunay, cen_interp, setpoint
                )
                for crossing in _edge_intersections(simplex_points, simplex_values, setpoint)
            ]
        )
        # More points than the 3 Gaussian parameters are required for a meaningful fit.
        if len(iso_points) > 3:
            # Amplitudes along the contour do not depend on the coordinate being fit.
            iso_amplitudes = amp_interp(iso_points)
            out_points.append(
                tuple(_fit_gauss(coordinate, iso_amplitudes) for coordinate in iso_points.T)
            )
        else:
            out_points.append((np.nan,) * ndim)

    return np.array(out_points)


def holistic(
    *,
    data,
    channels,
    arrangement,
    tunes,
    instrument,
    spectral_axis=-1,
    level=False,
    gtol=0.01,
    autosave=True,
    save_directory=None,
    **spline_kwargs,
):
    """Workup multi-dependent tuning data.

    Note: At this time, this function expects 2-dimensional motor space.
    The algorithm should generalize to N-dimensional motor space, however this is untested
    and plotting likely will fail.

    Parameters
    ----------
    data: WrightTools.Data
        The data object to process. All processing happens on a copy; the caller's data
        object is not modified.
    channels: WrightTools.data.Channel or int or str or 2-tuple
        If singular: the channel holding spectral data, from which the 0th and 1st moments
        will be taken along `spectral_axis` to obtain amplitudes and centers.
        If a tuple: (amplitudes, centers), then these channels will be used directly.
        Channel objects are matched by name against the copy of `data`.
    arrangement: str
        Key of the arrangement within `instrument` to modify. Its independent values are
        the setpoints for which new motor positions are determined.
    tunes: iterable of str
        Names of the tunes to modify in the instrument, in the same order as the axes of
        `data`. Must not be DiscreteTunes.
    instrument: attune.Instrument
        Instrument object to modify. Setpoints are determined from the instrument.

    Keyword Parameters
    ------------------
    spectral_axis: WrightTools.data.Axis or int or str (default -1)
        The axis along which to take moments. Only applies if a single channel is given.
    level: bool (default False)
        Toggle leveling data. If two channels are given, only the amplitudes are leveled.
        If a single channel is given, leveling occurs before taking the moments.
    gtol: float or None (default 0.01)
        Global tolerance for rejecting noise, relative to the global maximum amplitude.
        Amplitudes below ``gtol * max`` are clipped, and centers are set to NaN wherever
        the clipped amplitudes are NaN. ``None`` disables this step.
    autosave: bool (default True)
        Toggles saving of instrument files and images.
    save_directory: Path-like (Defaults to current working directory)
        Specify where to save files.
    **spline_kwargs:
        Extra arguments to pass to spline creation (e.g. s=0, k=1 for linear interpolation)

    Returns
    -------
    attune.Instrument
        A new instrument whose `tunes` in `arrangement` follow the fitted motor positions.
    """
    if isinstance(channels, (int, str, wt.data.Channel)):
        channels_metadata = _channel_name(channels)
    else:
        channels_metadata = [_channel_name(c) for c in channels]
    metadata = {
        "channels": channels_metadata,
        "arrangement": arrangement,
        "tunes": tunes,
        "spectral_axis": spectral_axis,
        "level": level,
        "gtol": gtol,
        "spline_kwargs": spline_kwargs,
    }
    transition = Transition("holistic", instrument, metadata=metadata, data=data)

    # collect: work on a copy so leveling/clipping never touches the caller's data
    data = data.copy()
    if isinstance(channels, (int, str, wt.data.Channel)):
        amplitudes, centers = _spectral_moments(data, channels, spectral_axis, level)
    else:
        amplitudes, centers = _channel_pair(data, channels, level)

    if gtol is not None:
        cutoff = amplitudes.max() * gtol
        amplitudes.clip(min=cutoff)
        centers[np.isnan(amplitudes)] = np.nan

    out_points = _holistic(data, amplitudes, centers, instrument[arrangement])
    setpoints = instrument[arrangement].independent
    splines = [wt.kit.Spline(setpoints, vals, **spline_kwargs) for vals in out_points.T]
    new_instrument = _gen_instr(instrument, arrangement, tunes, splines, transition)

    fig, _ = plot_holistic(
        data,
        amplitudes.natural_name,
        centers.natural_name,
        arrangement,
        tunes,
        new_instrument,
        instrument,
        out_points,
    )
    if autosave:
        save(new_instrument, fig, "holistic", save_directory)
    return new_instrument


def _channel_name(channel):
    """Return an int/str channel reference unchanged, otherwise the channel's natural name."""
    return channel if isinstance(channel, (int, str)) else channel.natural_name


def _spectral_moments(data, channel, spectral_axis, level):
    """Reduce a spectral channel of ``data`` (modified in place) to amplitude/center channels.

    Returns the (amplitudes, centers) channels of ``data``: the 0th and 1st moments of
    ``channel`` along ``spectral_axis``, after ``data`` is transformed onto its other axes.
    """
    # Resolve to a name: a (negative) index would shift once moment channels are appended.
    if isinstance(channel, int):
        channel = data.channel_names[channel]
    else:
        channel = _channel_name(channel)
    if level:
        data.level(channel, 0, -3)
    if isinstance(spectral_axis, int):
        spectral_axis = data.axis_names[spectral_axis]
    elif isinstance(spectral_axis, wt.data.Axis):
        spectral_axis = spectral_axis.expression
    getattr(data, spectral_axis).convert("nm")

    # take channel moments, collapsing only the spectral dimension
    resultant = wt.kit.joint_shape(*[a for a in data.axes if a.expression != spectral_axis])
    for moment in (0, 1):
        data.moment(axis=spectral_axis, channel=channel, resultant=resultant, moment=moment)
    amplitudes = data.channels[-2]
    centers = data.channels[-1]
    data.transform(*[a for a in data.axis_expressions if a != spectral_axis])
    return amplitudes, centers


def _channel_pair(data, channels, level):
    """Look up the (amplitudes, centers) channel pair in ``data``, optionally leveling amplitudes."""
    amplitudes_ref, centers_ref = channels
    amplitudes = data.channels[wt.kit.get_index(data.channel_names, _channel_name(amplitudes_ref))]
    centers = data.channels[wt.kit.get_index(data.channel_names, _channel_name(centers_ref))]
    if level:
        data.level(amplitudes.natural_name, 0, -3)
    return amplitudes, centers
def _gen_instr(instrument, arrangement, tunes, splines, transition):
    """Return a new Instrument whose ``tunes`` in ``arrangement`` follow ``splines``.

    Each tune's independent values become the arrangement's setpoints, and its dependent
    values are the corresponding spline evaluated at those setpoints.
    """
    new_instrument = instrument.as_dict()
    # The new instrument gets its own transition record instead of the old one.
    del new_instrument["transition"]
    setpoints = instrument[arrangement].independent
    arrangement_tunes = new_instrument["arrangements"][arrangement]["tunes"]
    for tune, spline in zip(tunes, splines):
        arrangement_tunes[tune]["independent"] = setpoints
        arrangement_tunes[tune]["dependent"] = spline(setpoints)
    return Instrument(**new_instrument, transition=transition)


def _find_simplices_containing(delaunay, interpolator, point):
    """Yield the simplices of ``delaunay`` whose vertex values bracket ``point``.

    Vertex values are ``interpolator`` evaluated at the triangulation's points. A simplex is
    yielded when ``min < point <= max`` over its non-NaN vertex values (the half-open
    interval matches ``_edge_intersections``).

    Yields
    ------
    tuple
        (vertex indices, vertex coordinates, vertex values) for each matching simplex.
    """
    # Evaluate every vertex once, in one vectorized call, instead of once per simplex.
    # NOTE(review): values come from evaluating the interpolator at the vertices, so a NaN
    # at one vertex may propagate to its neighbours; confirm this is acceptable.
    vertex_values = interpolator(delaunay.points)
    for simplex in delaunay.simplices:
        values = vertex_values[simplex]
        # Ignore NaNs explicitly; builtin min/max give order-dependent results with NaN.
        finite = values[~np.isnan(values)]
        if finite.size and finite.min() < point <= finite.max():
            yield simplex, delaunay.points[simplex], values


def _edge_intersections(points, evaluated, target):
    """Yield the points where a linear interpolant crosses ``target`` along simplex edges.

    ``points`` are the simplex vertices (one row each) and ``evaluated`` their values. For
    every vertex pair with ``v1 < target <= v2`` (after sorting by value), the crossing is
    found by linear interpolation along that edge and yielded as a coordinate tuple.
    """
    order = np.argsort(evaluated)
    evaluated = evaluated[order]
    points = points[order]
    for (p1, v1), (p2, v2) in itertools.combinations(zip(points, evaluated), 2):
        if v1 < target <= v2:
            fraction = (target - v1) / (v2 - v1)
            yield tuple(a + (b - a) * fraction for a, b in zip(p1, p2))


def _fit_gauss(x, y):
    """Return the center of a Gaussian least-squares fit of ``y`` versus ``x``.

    Pairs where either value is NaN are dropped first. The fitted center is bounded to the
    data range padded by 10% on each side.

    Returns
    -------
    float
        The fitted center; NaN if no valid pairs remain; the common x value if all remaining
        x are identical (the only admissible center, and a case least_squares rejects).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    keep = ~(np.isnan(x) | np.isnan(y))
    x, y = x[keep], y[keep]
    if x.size == 0:
        return np.nan
    x_min, x_max = np.min(x), np.max(x)
    x_range = x_max - x_min
    if x_range == 0:
        return x_min

    def residuals(params):
        return y - _gauss(*params)(x)

    # Parameter order: (center, sigma, amplitude); only the center is bounded.
    lower = [x_min - x_range / 10, -np.inf, -np.inf]
    upper = [x_max + x_range / 10, np.inf, np.inf]
    x0 = [np.median(x), x_range / 10, np.max(y)]
    fit = scipy.optimize.least_squares(residuals, x0, bounds=(lower, upper))
    return fit.x[0]


def _gauss(center, sigma, amplitude):
    """Return a callable evaluating ``amplitude * exp(-(x - center)**2 / (2 * sigma**2))``."""
    return lambda x: amplitude * np.exp(-1 / 2 * (x - center) ** 2 / sigma**2)