"""Base curve behavior."""

import re
import pathlib
import copy as copy_

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as grd

import WrightTools as wt
import tidy_headers
from ..interpolator import Linear, builtins
from ._dependent import Setpoints, Dependent

__all__ = ["Curve"]

# --- curve class ---------------------------------------------------------------------------------

class Curve:
    """Central object-type for all OPA tuning curves."""

    def __init__(
        self,
        setpoints,
        dependents,
        name,
        interaction=None,
        kind="curve",
        method=Linear,
        subcurve=None,
        source_setpoints=None,
        fmt=None,
        **kwargs,
    ):
        """Create a ``Curve`` object.

        Parameters
        ----------
        setpoints : attune.Setpoints
            The setpoint destinations for the curve.
        dependents : Dependent, list of Dependent, or dict of str to Dependent
            Dependent positions for each setpoint. A single Dependent or a list is
            keyed by each dependent's ``name``; a dict is used as given.
        name : str
            Name of curve.
        interaction : str (optional)
            Interaction label; used for plot axis labels and saved in the header.
        kind : string (optional)
            The kind of curve (for saving). Default is "curve".
        method : interpolation class (optional)
            The interpolation method to use. Keys of the interpolator ``builtins``
            mapping are resolved to the mapped class. Default is Linear.
        subcurve : Curve (optional)
            Curve evaluated at ``source_setpoints`` to extend the dependent set.
        source_setpoints : Dependent (optional)
            Setpoints of ``subcurve`` as a function of this curve's setpoints.
        fmt : list of str (optional)
            ``np.savetxt`` formats, one per saved column. Default is ``"%.2f"`` for
            the setpoints and ``"%.5f"`` for each dependent.
        **kwargs
            Extra attributes set on the instance.

        Raises
        ------
        ValueError
            If any dependent's length differs from the number of setpoints.
        """
        # version
        from .. import __version__

        self.__version__ = __version__
        self.setpoints = setpoints
        if isinstance(dependents, Dependent):
            dependents = [dependents]
        if isinstance(dependents, dict):
            self.dependents = dependents
        else:
            self.dependents = {d.name: d for d in dependents}
        self.name = name
        self.kind = kind
        self.subcurve = subcurve
        self.source_setpoints = source_setpoints
        self.interaction = interaction
        # set dependents as attributes of self (``curve.<name>``)
        for obj in self.dependents.values():
            if len(obj) != len(self.setpoints):
                raise ValueError("Dependents must be the same length as setpoints")
            setattr(self, obj.name, obj)
        # initialize function object
        self.method = builtins.get(method, method)
        if fmt is None:
            fmt = ["%.2f"] + ["%.5f"] * len(self.dependents)
        self.fmt = fmt
        self.interpolate()
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __add__(self, other):
        """Combine two curves over the overlap of their setpoint ranges.

        Both curves are copied, ``other`` is converted to this curve's setpoint units,
        and both are remapped onto a shared linear grid spanning the overlap. Shared
        dependents are summed (at least one side must be differential; the result is
        absolute unless both were differential). Dependents only in ``other`` are copied.

        Raises
        ------
        ValueError
            If the curves do not overlap, a shared dependent's units cannot be
            converted, or both copies of a shared dependent are absolute.
        """
        # copy
        self_ = self.copy()
        other = other.copy()
        # coerce other to own units
        other.convert(self_.setpoints.units)
        # find new control points
        other_limits = other.get_limits()
        self_limits = self_.get_limits()
        min_limit = max(other_limits[0], self_limits[0])
        max_limit = min(other_limits[1], self_limits[1])
        if min_limit > max_limit:
            raise ValueError("Curves must overlap")
        num_points = max(other.setpoints[:].size, self_.setpoints[:].size)
        new_setpoints = np.linspace(min_limit, max_limit, num_points)
        # coerce to new control points
        other.map_setpoints(new_setpoints)
        self_.map_setpoints(new_setpoints)
        # add
        self_keys = set(self_.dependents.keys())
        other_keys = set(other.dependents.keys())
        for k in self_keys | other_keys:
            if k in self_keys and k in other_keys:
                if wt.units.is_valid_conversion(other[k].units, self_[k].units):
                    other[k].convert(self_[k].units)
                else:
                    raise ValueError(
                        f"Invalid unit conversion: {other[k].units} -> {self_[k].units}"
                    )
                if self_[k].differential and other[k].differential:
                    self_[k][:] += other[k][:]
                elif self_[k].differential or other[k].differential:
                    self_[k][:] += other[k][:]
                    self_[k].differential = False
                else:
                    raise ValueError(f"Cannot add two Dependents which are both absolute: {k}")
            elif k in other_keys:
                self_.dependents[k] = copy_.deepcopy(other[k])
                # keep attribute access consistent with __init__/map_setpoints
                setattr(self_, k, self_.dependents[k])
        self_.interpolate()
        return self_

    def __getitem__(self, key):
        """Look up a dependent by name, falling back to the subcurve.

        Raises
        ------
        KeyError
            If ``key`` is not a dependent here and there is no subcurve to search.
        """
        if key in self.dependents:
            return self.dependents[key]
        if self.subcurve is None:
            raise KeyError(key)
        return self.subcurve[key]

    def __setitem__(self, key, value):
        """Add or replace dependent ``key`` with a deep copy of ``value``.

        If ``value`` already has an interpolator, it is re-sampled at this curve's
        setpoints; otherwise it must have exactly one position per setpoint.
        """
        value = copy_.deepcopy(value)
        value.name = key
        if value.interpolator is not None:
            value.positions = value(self.setpoints[:], self.setpoints.units)
        elif len(value) != len(self.setpoints):
            raise ValueError(
                f"Incorrect number of points in dependent: {len(value)} for number of setpoints: {len(self.setpoints)}"
            )
        value.interpolator = self.method(self.setpoints, value)
        self.dependents[key] = value
        # keep ``curve.<key>`` in sync, as __init__ and map_setpoints do
        setattr(self, key, value)

    def __call__(self, value, units=None, full=True):
        """Shortcut for ``get_dependent_positions(value, units, full=full)``."""
        return self.get_dependent_positions(value, units, full=full)

    @property
    def dependent_names(self):
        """Get dependent names, including those of any subcurve.

        Returns
        -------
        list of strings
            Names of this curve's dependents, followed by the subcurve's dependent
            names (recursively) when a subcurve is present.
        """
        if self.subcurve:
            subcurve_dependent_names = self.subcurve.dependent_names
        else:
            subcurve_dependent_names = []
        return [m.name for m in self.dependents.values()] + subcurve_dependent_names

    @property
    def dependent_units(self):
        """Get dependent units, including those of any subcurve.

        Returns
        -------
        list of strings
            Units of this curve's dependents, followed by the subcurve's dependent
            units (recursively) when a subcurve is present. Ordered like
            ``dependent_names``.
        """
        if self.subcurve:
            subcurve_dependent_units = self.subcurve.dependent_units
        else:
            subcurve_dependent_units = []
        return [m.units for m in self.dependents.values()] + subcurve_dependent_units

    def coerce_dependents(self):
        """Coerce the dependent positions to lie exactly along the interpolation positions.

        Can be thought of as 'smoothing' the curve.
        """
        self.map_setpoints(self.setpoints[:], units="same")

    def convert(self, units, *, convert_dependents=False):
        """Convert the setpoints to new units.

        Parameters
        ----------
        units : str
            The destination units.
        convert_dependents : bool (optional, keyword-only)
            Also convert each dependent whose units are convertible to ``units``.
            Default is False.

        Notes
        -----
        When a subcurve is present, its source setpoints are converted too if the
        conversion is valid for them.
        """
        self.setpoints.convert(units)
        if self.subcurve:
            if wt.units.is_valid_conversion(self.source_setpoints.units, units):
                self.source_setpoints.convert(units)
        if convert_dependents:
            # iterate the Dependent objects, not the dict's string keys
            for d in self.dependents.values():
                if wt.units.is_valid_conversion(d.units, units):
                    d.convert(units)
        self.interpolate()

    def copy(self):
        """Copy the curve object.

        Returns
        -------
        curve
            A deep copy of the curve object.
        """
        return copy_.deepcopy(self)

    def get_limits(self, units="same"):
        """Get the edges of the curve.

        Parameters
        ----------
        units : str (optional)
            The units to return. Default is same.

        Returns
        -------
        list of floats
            [min, max] in given units
        """
        if units == "same":
            return [self.setpoints[:].min(), self.setpoints[:].max()]
        units_setpoints = wt.units.convert(self.setpoints[:], self.setpoints.units, units)
        return [units_setpoints.min(), units_setpoints.max()]

    def get_dependent_positions(self, setpoint, units="same", full=True):
        """Get the dependent positions for a destination setpoint.

        Parameters
        ----------
        setpoint : number
            The destination setpoint. May be 1D array.
        units : str (optional)
            The units of the input setpoint.
        full : bool (optional)
            Include the subcurve's dependents, evaluated at the corresponding
            source setpoint(s). Default is True.

        Returns
        -------
        dict
            Maps each dependent name to its position(s) at ``setpoint``.
        """
        out = {k: v(setpoint, units) for k, v in self.dependents.items()}
        if full and self.subcurve:
            out.update(
                self.subcurve(
                    self.source_setpoints(setpoint, units), self.source_setpoints.units, full
                )
            )
        return out

    def get_source_setpoint(self, setpoint, units="same"):
        """Get setpoint of source curve.

        Parameters
        ----------
        setpoint : number or 1D array
            Setpoint(s).
        units : string (optional)
            Units. Default is same.

        Returns
        -------
        number or 1D array or None
            Source setpoint(s), or None when there is no subcurve.
        """
        if not self.subcurve:
            return None
        return self.source_setpoints(setpoint, units)

    def interpolate(self, interpolate_subcurve=True):
        """Generate the interpolator objects.

        Parameters
        ----------
        interpolate_subcurve : boolean (optional)
            Toggle interpolation of the source setpoints. Default is True.
        """
        for d in self.dependents.values():
            d.interpolator = self.method(self.setpoints, d)
        if self.subcurve and interpolate_subcurve:
            self.source_setpoints.interpolator = self.method(self.setpoints, self.source_setpoints)

    def map_setpoints(self, setpoints, units="same"):
        """Map the curve onto new tune points using the curve's own interpolation method.

        Parameters
        ----------
        setpoints : int or array
            The number of new points (between current limits) or the new points
            themselves.
        units : str (optional)
            The units of ``setpoints`` (or of the limits used to space them when an
            int is given). Default is same.

        Notes
        -----
        Units of curve object are not changed by map_setpoints. The new setpoints
        are sorted ascending.
        """
        # get new setpoints in input units
        if isinstance(setpoints, (int, np.integer)):
            limits = self.get_limits(units)
            new_setpoints = np.linspace(limits[0], limits[1], setpoints)
        else:
            new_setpoints = setpoints
        # convert new setpoints to local units
        if units == "same":
            units = self.setpoints.units
        new_setpoints = np.sort(wt.units.converter(new_setpoints, units, self.setpoints.units))
        # ensure that dependent interpolators agree with current dependent positions
        self.interpolate()
        # map own dependents
        new_dependents = {}
        for k, v in self.dependents.items():
            positions = v(new_setpoints)
            new_dependents[k] = Dependent(positions, k, v.units, v.differential, v.index)
        # map source setpoints, subcurves
        if self.subcurve:
            new_source_setpoints = self.source_setpoints(new_setpoints)
            self.source_setpoints = Dependent(
                new_source_setpoints,
                self.source_setpoints.name,
                self.source_setpoints.units,
                index=self.source_setpoints.index,
            )
        # finish
        self.setpoints = Setpoints(new_setpoints, self.setpoints.name, self.setpoints.units)
        self.dependents = new_dependents
        for obj in self.dependents.values():
            setattr(self, obj.name, obj)
        self.interpolate()

    def sort(self):
        """Sort setpoints ascending, reordering source setpoints and dependents to match."""
        order = self.setpoints[:].argsort()
        self.setpoints[:] = self.setpoints[order]
        try:
            self.source_setpoints[:] = self.source_setpoints[order]
        except (AttributeError, TypeError):
            pass  # no subcurve setpoints
        for d in self.dependents.values():
            d[:] = d[order]
        self.interpolate()

    def offset_by(self, dependent, amount):
        """Offset a dependent by some amount.

        Parameters
        ----------
        dependent : str
            The dependent name.
        amount : number
            The offset.

        See Also
        --------
        offset_to
        """
        # offset
        self.dependents[dependent].positions += amount
        self.dependents[dependent].interpolate()

    def offset_to(self, dependent, destination, setpoint, setpoint_units="same"):
        """Offset a dependent such that it evaluates to `destination` at `setpoint`.

        Parameters
        ----------
        dependent : str
            The dependent name.
        destination : number
            The dependent position at setpoint after offsetting.
        setpoint : number
            The setpoint at which the dependent should equal ``destination``.
        setpoint_units : str (optional)
            The setpoint units. Default is same.

        See Also
        --------
        offset_by
        """
        offset = destination - self[dependent](setpoint, setpoint_units)
        self.offset_by(dependent, offset)

    def plot(self, autosave=False, save_path="", title=None):
        """Plot the curve and all of its subcurves.

        One stacked subplot is drawn per dependent, grouped by curve (deepest
        subcurve at the top) with a blank row between curves. All subplots share
        this curve's setpoint axis; each group's lowest subplot is tick-labelled
        with that curve's setpoints corresponding to this curve's setpoints.

        Parameters
        ----------
        autosave : bool (optional)
            Save the figure as PNG at ``save_path`` and close it. Default is False.
        save_path : str or path-like (optional)
            Image destination; its suffix is replaced with ``.png``.
        title : str (optional)
            Figure title. Default is the curve name.
        """
        # count number of subcurves
        subcurve_count = 0
        total_dependent_count = len(self.dependents)
        current_curve = self
        all_curves = [self]
        while current_curve.subcurve:
            subcurve_count += 1
            total_dependent_count += len(current_curve.subcurve.dependents)
            current_curve = current_curve.subcurve
            all_curves.append(current_curve)
        all_curves = all_curves[::-1]
        # prepare figure
        num_subplots = total_dependent_count + subcurve_count
        fig = plt.figure(figsize=(8, 2 * num_subplots))
        axs = grd.GridSpec(num_subplots, 1, hspace=0)
        # assign subplot indices
        ax_index = 0
        ax_dictionary = {}
        lowest_ax_dictionary = {}
        for curve_index, curve in enumerate(all_curves):
            for dependent in curve.dependents:
                ax_dictionary[dependent] = axs[ax_index]
                lowest_ax_dictionary[curve.name] = axs[ax_index]
                ax_index += 1
            # blank spacer row between consecutive curves (not after the last one)
            if curve_index != len(all_curves) - 1:
                ax_index += 1
        # add scatter
        limits = self.get_limits()
        xi_dense = np.linspace(limits[0], limits[1], 1000)
        for dependent in self.dependent_names:
            ax = plt.subplot(ax_dictionary[dependent])
            xi = self.setpoints[:]
            yi = self(xi)[dependent]
            ax.scatter(xi, yi, c="k")
            yi = self(xi_dense)[dependent]
            ax.plot(xi_dense, yi, c="k")
            ax.set_ylabel(dependent)
            plt.xticks(self.setpoints[:])
            plt.setp(ax.get_xticklabels(), visible=False)
        # get appropriate source setpoints: walk from self down to each curve
        source_setpoint_arrs = {}
        for curve_index, curve in enumerate(all_curves):
            current_curve = self
            current_arr = self.setpoints[:]
            for _ in range(len(all_curves) - curve_index - 1):
                current_arr = current_curve.get_source_setpoint(current_arr)
                current_curve = current_curve.subcurve
            source_setpoint_arrs[curve.name] = np.array(current_arr).flatten()
        # add labels
        for curve in all_curves:
            ax = plt.subplot(lowest_ax_dictionary[curve.name])
            plt.setp(ax.get_xticklabels(), visible=True)
            label = curve.interaction if curve.interaction is not None else curve.name
            ax.set_xlabel(label + " setpoint ({})".format(self.setpoints.units))
            xtick_positions = self.setpoints[:]
            xtick_labels = [str(np.around(x, 1)) for x in source_setpoint_arrs[curve.name]]
            plt.xticks(xtick_positions, xtick_labels, rotation=45)
        # formatting details
        xmin = self.setpoints[:].min() - np.abs(self.setpoints[0] - self.setpoints[1])
        xmax = self.setpoints[:].max() + np.abs(self.setpoints[0] - self.setpoints[1])
        for ax in ax_dictionary.values():
            ax = plt.subplot(ax)
            plt.xlim(xmin, xmax)
            plt.grid()
            ax.get_yaxis().get_major_formatter().set_useOffset(False)
            yticks = ax.yaxis.get_major_ticks()
            yticks[0].label1.set_visible(False)
            yticks[-1].label1.set_visible(False)
        # title
        if title is None:
            title = self.name
        plt.suptitle(title)
        # save
        if autosave:
            save_path = pathlib.Path(save_path)
            image_path = save_path.with_suffix(".png")
            plt.savefig(image_path, transparent=True, dpi=300)
            plt.close(fig)

    @classmethod
    def read(cls, filepath, subcurve=None):
        """Read a curve from a file in the format written by ``save``.

        Parameters
        ----------
        filepath : str or path-like
            The curve file.
        subcurve : Curve (optional)
            Subcurve to attach. When given, this curve's setpoints also become its
            source setpoints.

        Returns
        -------
        Curve
            The curve described by the file.
        """
        filepath = pathlib.Path(filepath)
        headers = tidy_headers.read(filepath)
        arr = np.genfromtxt(filepath).T
        names = headers["name"]
        units = headers.get("units", [None] * len(names))
        differential = headers.get("differential", [False] * len(names))
        dependents = []
        for a, n, u, d in zip(arr[1:], names[1:], units[1:], differential[1:]):
            dependents.append(Dependent(a, n, units=u, differential=d))
        kwargs = {}
        kwargs["interaction"] = headers.get("interaction", None)
        kwargs["kind"] = headers.get("kind", None)
        kwargs["method"] = builtins.get(headers.get("method", ""), Linear)
        kwargs["name"] = headers.get("curve name", filepath.stem)
        kwargs["fmt"] = headers.get("fmt", ["%.2f"] + ["%.5f"] * len(dependents))
        setpoint_name = names[0]
        # Handle pre-attune release curves: units embedded as "name (units)"
        if units[0] is None:
            try:
                match = re.match(r"(.*)\((.*)\).*", names[0])
                setpoint_name = match[1].strip()
                units[0] = match[2].strip()
            except TypeError:
                pass  # No units
        setpoints = Setpoints(arr[0], setpoint_name, units[0])
        if subcurve is not None:
            kwargs["subcurve"] = subcurve
            kwargs["source_setpoints"] = Dependent(setpoints[:], setpoint_name, units=units[0])
        # finish
        curve = cls(setpoints, dependents, **kwargs)
        return curve

    def save(self, save_directory=None, plot=True, verbose=True, full=False):
        """Save the curve.

        Parameters
        ----------
        save_directory : str (optional)
            The save directory. If not supplied, current working directory is used.
        plot : bool (optional)
            Toggle saving plot along with curve. Default is True.
        verbose : bool (optional)
            Toggle talkback. Default is True.
        full : bool (optional)
            Include all files (if curve is stored in multiple files), i.e. also
            save the subcurve. Default is False.

        Returns
        -------
        pathlib.Path
            The filepath of the saved curve.
        """
        # get save directory
        if save_directory is None:
            save_directory = pathlib.Path()
        else:
            save_directory = pathlib.Path(save_directory)
        # array: setpoints in row 0, one row per dependent after it
        out_arr = np.zeros([len(self.dependents) + 1, len(self.setpoints)])
        out_arr[0] = self.setpoints[:]
        for row, dependent in enumerate(self.dependents.values(), start=1):
            out_arr[row] = dependent.positions
        # filename
        timestamp = wt.kit.TimeStamp()
        out_name = self.name.split("-")[0] + "- " + timestamp.path
        # append the suffix directly: with_suffix would truncate names containing "."
        out_path = save_directory / (out_name + ".curve")
        # save
        headers = {}
        headers["curve name"] = self.name
        headers["file created"] = timestamp.RFC3339
        headers["interaction"] = self.interaction
        headers["kind"] = self.kind
        headers["method"] = self.method.__name__
        headers["units"] = [self.setpoints.units] + [d.units for d in self.dependents.values()]
        headers["differential"] = [None] + [d.differential for d in self.dependents.values()]
        headers["name"] = [f"{self.setpoints.name}"] + [d.name for d in self.dependents.values()]
        tidy_headers.write(out_path, headers)
        with open(out_path, "at") as f:
            np.savetxt(f, out_arr.T, fmt=self.fmt, delimiter="\t")
        # save subcurve
        # NOTE(review): reconstructed statement — confirm against upstream whether the
        # subcurve should be saved unconditionally or only when ``full`` is set.
        if full and self.subcurve:
            self.subcurve.save(
                save_directory=save_directory, plot=plot, verbose=verbose, full=full
            )
        # plot
        if plot:
            image_path = out_path.with_suffix(".png")
            title = out_path.stem
            self.plot(autosave=True, save_path=image_path, title=title)
        # finish
        if verbose:
            print("curve saved at", out_path)
        return out_path