Source code for noggin.plotter

from math import ceil
import importlib
import time
from collections import OrderedDict, defaultdict, deque
from collections.abc import Sequence
from inspect import cleandoc
from itertools import product
from numbers import Integral, Real
from typing import Dict, Optional, Set, Tuple, Union
from warnings import warn

import numpy as np
from matplotlib.pyplot import Axes, Figure

from noggin.logger import LiveLogger, LiveMetric
from noggin.typing import Metrics, ValidColor

__all__ = ["LivePlot"]


def _check_valid_color(c: ValidColor) -> bool:
    """
    Validate that `c` is either `None` or something matplotlib accepts as a color.

    Parameters
    ----------
    c : Union[str, Real, Sequence[Real], NoneType]
        The candidate color value.

    Returns
    -------
    bool
        Always `True`; an invalid color is reported by raising instead.

    Raises
    ------
    ValueError
        `c` is not `None` and matplotlib does not recognize it as a color."""
    # imported lazily, at call time, rather than at module import
    from matplotlib.colors import is_color_like

    if c is None or is_color_like(c):
        return True
    raise ValueError("{} is not a valid matplotlib color".format(repr(c)))


class LivePlot(LiveLogger):
    """Records and plots batch-level and epoch-level summary statistics
    of the training and testing metrics of a model during a session.

    The rate at which the plot is updated is controlled by
    :obj:`~noggin.plotter.LivePlot.max_fraction_spent_plotting`.

    The maximum number of batches to be included in the plot is controlled by
    :obj:`~noggin.plotter.LivePlot.last_n_batches`.

    Notes
    -----
    Live plotting is only supported for the 'nbAgg' backend (i.e.
    when the cell magic ``%matplotlib notebook`` is invoked in a
    jupyter notebook).
    """

    @property
    def metrics(self) -> Tuple[str, ...]:
        """A tuple of all the metric names"""
        return self._metrics

    @property
    def metric_colors(self) -> Dict[str, Dict[str, ValidColor]]:
        """The color associated with each of the train/test and
        batch/epoch-level metrics.

        Returns
        -------
        Dict[str, Dict[str, color-value]]
            {'<metric-name>' -> {'train'/'test' -> color-value}}"""
        out = defaultdict(dict)
        for name, color in self._train_colors.items():
            out[name]["train"] = color
        for name, color in self._test_colors.items():
            out[name]["test"] = color
        return dict(out)

    @metric_colors.setter
    def metric_colors(self, value: Dict[str, Union[ValidColor, Dict[str, ValidColor]]]):
        if not isinstance(value, dict):
            raise TypeError(
                "`metric_colors` must be a dictionary that maps:"
                "\nmetric-name -> valid-color"
                "\nor"
                "\nmetric-name -> 'train' -> valid-color"
                "\n 'test' -> valid-color"
                "\nGot: {}".format(value)
            )

        # Stage the requested changes; entries for unregistered metrics are ignored.
        train_updates = {}
        test_updates = {}
        for name, color in value.items():
            if name not in self.metrics:
                continue
            if isinstance(color, dict):
                train_updates[name] = color.get("train")
                test_updates[name] = color.get("test")
            else:
                # a bare color specifies the train-metric color only
                train_updates[name] = color

        # Validate the would-be state *before* committing it, so that an
        # invalid color raises without leaving the plotter half-updated.
        # Merging preserves the iteration order the stored mappings would
        # have after the update.
        for merged in (
            {**self._train_colors, **train_updates},
            {**self._test_colors, **test_updates},
        ):
            for color in merged.values():
                _check_valid_color(color)

        self._train_colors.update(train_updates)
        self._test_colors.update(test_updates)

    @property
    def figsize(self) -> Optional[Tuple[float, float]]:
        """Returns the current size of the figure in inches.

        Returns
        -------
        Optional[Tuple[float, float]]
            The stored ``figsize`` plot-setting; ``None`` if none is stored."""
        return self._pltkwargs.get("figsize")

    @figsize.setter
    def figsize(self, size: Tuple[float, float]):
        if (
            not isinstance(size, Sequence)
            or len(size) != 2
            or not all(isinstance(x, Real) and x > 0 for x in size)
        ):
            raise ValueError(
                f"`size` must be a length-2 sequence of "
                f"positive-valued numbers, got: {size}"
            )
        size = tuple(size)
        if self.figsize != size:
            self._pltkwargs["figsize"] = size
            # resize the figure too, if it has already been created
            if self._fig is not None:
                self._fig.set_size_inches(size)

    @property
    def plot_objects(self) -> Union[Tuple[Figure, Axes], Tuple[Figure, np.ndarray]]:
        """
        The figure-instance of the plot, and the axis-instance for each metric.

        Notes
        -----
        Calling this method will initialize the plot window if it is not
        already rendered.

        Returns
        -------
        Union[Tuple[Figure, Axes], Tuple[Figure, np.ndarray]]
            If more than one set of axes are present in the figure, an array of
            axes is returned instead."""
        if self._fig is None:
            self._init_plot_window()

        if self._axes.size == 1:
            axis = self._axes.item()  # type: Axes
            return self._fig, axis
        return self._fig, self._axes

    @property
    def max_fraction_spent_plotting(self) -> float:
        """The maximum fraction of time spent plotting.

        Parameters
        ----------
        value : float
            A value in [0, 1]. A value of ``0.0`` turns live-plotting off.
            A value of ``1.0`` will result in the plot updating whenever a new
            measurement is recorded.

        Notes
        -----
        The refresh rate for plotting will update dynamically such that::

            mean_plot_time / (time_since_last_plot + mean_plot_time)

        does not exceed ``max_fraction_spent_plotting``."""
        return self._max_fraction_spent_plotting

    @max_fraction_spent_plotting.setter
    def max_fraction_spent_plotting(self, value: float):
        if not isinstance(value, (int, float)):
            raise TypeError(
                "`max_fraction_spent_plotting` must be a "
                "floating point number in [0, 1], got {}".format(value)
            )
        if not 0 <= value <= 1:
            raise ValueError(
                "`max_fraction_spent_plotting` must be a "
                "floating point number in [0, 1], got {}".format(value)
            )
        self._max_fraction_spent_plotting = value

    @property
    def last_n_batches(self) -> Optional[int]:
        """The maximum number of batches to be plotted at any given time.
        If ``None``, all data will be plotted.

        Parameters
        ----------
        value : Union[int, None]
        """
        return self._last_n_batches

    @last_n_batches.setter
    def last_n_batches(self, value: Optional[int]):
        if value is not None:
            if not isinstance(value, int):
                raise TypeError(
                    "`last_n_batches` must be a positive integer, got {}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "`last_n_batches` must be a positive integer, got {}".format(value)
                )

        # Points to starting index for the epoch-domain of a
        # given metric's name; used to keep epoch plot within
        # "last-n-batches" plotted.
        # This must be reset each time `last_n_batches` is set
        self._epoch_domain_lookup = dict(
            train=defaultdict(int), test=defaultdict(int)
        )  # type: Dict[str, Dict[str, int]]
        self._last_n_batches = value

    def __init__(
        self,
        metrics: Metrics,
        max_fraction_spent_plotting: float = 0.05,
        last_n_batches: Optional[int] = None,
        nrows: Optional[int] = None,
        ncols: int = 1,
        figsize: Optional[Tuple[int, int]] = None,
    ):
        """
        Parameters
        ----------
        metrics : Union[str, Sequence[str], Dict[str, valid-color], Dict[str, Dict['train'/'test', valid-color]]]
            The name, or sequence of names, of the metric(s) that will be plotted.

            ``metrics`` can also be a dictionary, specifying the colors used to plot
            the metrics. Two mappings are valid:
                - '<metric-name>' -> color-value  (specifies train-metric color only)
                - '<metric-name>' -> {'train'/'test' : color-value}

        max_fraction_spent_plotting : float, optional (default=0.05)
            The maximum fraction of time spent plotting. The default value is ``0.05``,
            meaning that no more than 5% of processing time will be spent plotting, on
            average.

        last_n_batches : Optional[int]
            The maximum number of batches to be plotted at any given time.
            If ``None``, all data will be plotted.

        nrows : Optional[int]
            Number of rows of the subplot grid. Metrics are added in
            row-major order to fill the grid.

        ncols : int, optional, default: 1
            Number of columns of the subplot grid. Metrics are added in
            row-major order to fill the grid.

        figsize : Optional[Sequence[float, float]]
            Specifies the width and height, respectively, of the figure.

        Raises
        ------
        TypeError
            A metric name is not a string, or a setting has the wrong type.
        ValueError
            Metric names are repeated or absent, or a setting is out of range."""
        # initializes the batch and epoch numbers
        super().__init__()

        # import matplotlib and check backend
        self._pyplot = importlib.import_module("matplotlib.pyplot")
        _matplotlib = importlib.import_module("matplotlib")

        self._backend = _matplotlib.get_backend()
        self._liveplot = "nbAgg" in self._backend

        # plot-settings for batch and epoch data
        self._batch_ax = dict(ls="-", alpha=0.5)
        self._epoch_ax = dict(ls="-", marker="o", markersize=6, lw=3)
        self._legend = dict()

        # metric name -> matplotlib axis object
        self._axis_mapping = OrderedDict()  # type: Dict[str, Axes]

        # plot objects
        self._fig = None  # type: Optional[Figure]
        self._axes = None  # type: Union[None, Axes, np.ndarray]

        # plotting logic
        self._plot_batch = True  # type: bool
        self._last_plot_time = None  # type: Optional[float]
        self._plot_time_queue = deque([])  # stores most recent plot-times (seconds)
        self._time_of_last_liveplot_attempt = None  # type: Optional[float]
        self._draw_time = 0.0  # type: float

        # 'train/test' -> {metric-name -> batch-index of most-recent epoch}
        self._epoch_domain_lookup = dict(
            train=defaultdict(int), test=defaultdict(int)
        )  # type: Dict[str, Dict[str, int]]
        self.last_n_batches = last_n_batches

        # used to warn users only once when they plot an unregistered metric
        self._unregistered_metrics = set()  # type: Set[str]

        # maximum number of plot-times retained in `_plot_time_queue`
        self._queue_size = 4
        self.max_fraction_spent_plotting = max_fraction_spent_plotting

        # input parameters
        self._metrics = (metrics,) if isinstance(metrics, str) else tuple(metrics)

        if not len(self._metrics) == len(set(self._metrics)):
            from collections import Counter

            count = Counter(self._metrics)
            _items = [name for name, cnt in count.most_common() if cnt > 1]
            raise ValueError(
                "`metrics` must specify mutually-unique names. "
                "\n `{}` {} specified redundantly".format(
                    ", ".join(_items), "was" if len(_items) == 1 else "were"
                )
            )

        if not self._metrics:
            raise ValueError("At least one metric must be specified")

        if any(not isinstance(i, str) for i in self._metrics):
            raise TypeError("`metrics` must be a string or a collection of strings")

        if nrows is None:
            nrows = 1

        if not isinstance(nrows, Integral) or 1 > nrows:
            raise ValueError(
                "`nrows` must integer-valued and be at least 1. Got {}".format(nrows)
            )

        if not isinstance(ncols, Integral) or 1 > ncols:
            raise ValueError(
                "`ncols` must integer-valued and be at least 1. Got {}".format(ncols)
            )

        # grow the grid, row-wise, until every metric has its own axes
        if len(self._metrics) > ncols * nrows:
            nrows = int(ceil(len(self._metrics) / ncols))

        assert nrows * ncols >= len(self._metrics)

        self._pltkwargs = dict(nrows=nrows, ncols=ncols)

        if figsize is not None:
            self.figsize = figsize  # validated by the property setter
        else:
            self._pltkwargs["figsize"] = None

        # color config
        self._train_colors = defaultdict(lambda: None)
        self._test_colors = defaultdict(lambda: None)
        if isinstance(metrics, dict):
            self.metric_colors = metrics

        if "nbAgg" not in self._backend and max_fraction_spent_plotting > 0.0:
            _inline_msg = """Live plotting is not supported when matplotlib uses the '{}'
                             backend. Instead, use the 'nbAgg' backend.

                             In a Jupyter notebook, this can be activated using the cell magic:
                                %matplotlib notebook."""
            warn(cleandoc(_inline_msg.format(self._backend)))

    def to_dict(self):
        """Records the state of the plotter in a dictionary.

        This is the inverse of :func:`~noggin.plotter.LivePlot.from_dict`

        Returns
        -------
        Dict[str, Any]

        Notes
        -----
        To save your plotter, use this method to convert it to a dictionary
        and then pickle the dictionary.
        """
        out = super().to_dict()
        out.update(
            dict(
                max_fraction_spent_plotting=self.max_fraction_spent_plotting,
                last_n_batches=self.last_n_batches,
                # copied so that later changes to the plotter (e.g. setting
                # `figsize`) cannot alter an already-recorded state
                pltkwargs=dict(self._pltkwargs),
                train_colors=dict(self._train_colors),
                test_colors=dict(self._test_colors),
                metric_names=self._metrics,
            )
        )
        return out

    @classmethod
    def from_dict(cls, plotter_dict):
        """Restores a plotter from the state recorded in a dictionary.

        This is the inverse of :func:`~noggin.plotter.LivePlot.to_dict`

        Parameters
        ----------
        plotter_dict : Dict[str, Any]
            The dictionary storing the state of the logger to be
            restored.

        Returns
        -------
        noggin.LivePlot
            The restored plotter.

        Notes
        -----
        This is a class-method, the syntax for invoking it is:

        >>> loaded_plotter = LivePlot.from_dict(plotter_dict)

        To restore your plot from the loaded plotter, call:

        >>> loaded_plotter.plot()
        """
        new = cls(
            metrics=plotter_dict["metric_names"],
            max_fraction_spent_plotting=plotter_dict["max_fraction_spent_plotting"],
            last_n_batches=plotter_dict["last_n_batches"],
        )

        new._train_metrics.update(
            (key, LiveMetric.from_dict(metric))
            for key, metric in plotter_dict["train_metrics"].items()
        )

        new._test_metrics.update(
            (key, LiveMetric.from_dict(metric))
            for key, metric in plotter_dict["test_metrics"].items()
        )

        # restores `_num_train_batch`, `_num_train_epoch`, `_num_test_batch`, ...
        for train_mode, stat_mode in product(["train", "test"], ["batch", "epoch"]):
            item = "num_{}_{}".format(train_mode, stat_mode)
            setattr(new, "_" + item, plotter_dict[item])

        # Copy the stored settings rather than adopting the caller's objects;
        # the plotter mutates `_pltkwargs` and its color mappings in place.
        new._pltkwargs = dict(plotter_dict["pltkwargs"])
        new._train_colors = defaultdict(lambda: None, plotter_dict["train_colors"])
        new._test_colors = defaultdict(lambda: None, plotter_dict["test_colors"])
        return new

    def _filter_unregistered_metrics(self, metrics: Dict[str, Real]) -> Dict[str, Real]:
        """
        Drops any metrics that were not registered with this plotter, warning
        about each unknown name the first time that it is encountered.

        Parameters
        ----------
        metrics : Dict[str, Real]
            Mapping of metric-name to value.

        Returns
        -------
        Dict[str, Real]
            A dictionary containing only registered metric-names. ``metrics``
            itself is returned when it contains no unknown names.

        Warns
        -----
        UserWarning
            Unknown metric was logged
        """
        unknown_metrics = set(metrics).difference(self._metrics)
        not_yet_reported = unknown_metrics - self._unregistered_metrics

        if not_yet_reported:
            msg = "\nThe following metrics are not registered for live-plotting: "
            warn(msg + ", ".join(sorted(not_yet_reported)))
            self._unregistered_metrics.update(unknown_metrics)

        if not unknown_metrics:
            return metrics
        return {k: v for k, v in metrics.items() if k in self._metrics}

    def set_train_batch(
        self, metrics: Dict[str, Real], batch_size: Integral, plot: bool = True
    ):
        """Record batch-level measurements for train-metrics, and (optionally)
        plot them.

        Parameters
        ----------
        metrics : Dict[str, Real]
            Mapping of metric-name to value. Only those metrics that were
            registered when initializing LivePlot will be recorded.

        batch_size : Integral
            The number of samples in the batch used to produce the metrics.
            Used to weight the metrics to produce epoch-level statistics.

        plot : bool
            If True, plot the batch-metrics (adhering to the refresh rate)"""
        super().set_train_batch(
            self._filter_unregistered_metrics(metrics), batch_size=batch_size
        )

        # remembered so that subsequent live-plots also include/exclude batch data
        self._plot_batch = plot
        if self._plot_batch:
            self._do_liveplot()

    def set_train_epoch(self):
        """Record and plot an epoch for the train-metrics.

        Computes epoch-level statistics based on the batches accumulated since
        the prior epoch.
        """
        super().set_train_epoch()
        self._do_liveplot()

    def set_test_batch(self, metrics: Dict[str, Real], batch_size: Integral):
        """Record batch-level measurements for test-metrics.

        Parameters
        ----------
        metrics : Dict[str, Real]
            Mapping of metric-name to value. Only those metrics that were
            registered when initializing LivePlot will be recorded.

        batch_size : Integral
            The number of samples in the batch used to produce the metrics.
            Used to weight the metrics to produce epoch-level statistics."""
        super().set_test_batch(
            self._filter_unregistered_metrics(metrics), batch_size=batch_size
        )

    def set_test_epoch(self):
        """Record and plot an epoch for the test-metrics.

        Computes epoch-level statistics based on the batches accumulated since
        the prior epoch.
        """
        super().set_test_epoch()
        self._do_liveplot()

    def _init_plot_window(self):
        """Creates the figure and its grid of axes, one per metric.

        Does nothing if the figure already exists."""
        if self._fig is not None:
            return None

        self._fig, self._axes = self._pyplot.subplots(sharex=True, **self._pltkwargs)
        self._fig.tight_layout()

        # record the size that the figure actually received
        self._pltkwargs["figsize"] = tuple(self._fig.get_size_inches())

        if len(self._metrics) == 1:
            self._axes = np.array([self._axes])

        # remove unused axes from plot grid
        axis_offset = self._axes.size - len(self._metrics)
        for ax in self._axes.flat[::-1][:axis_offset]:
            ax.remove()

        self._axis_mapping.update(zip(self._metrics, self._axes.flat))

        for ax in self._axes.flat:
            ax.grid(True)

        # Add x-label to bottom-plot for each column
        for i in range(min(self._pltkwargs["ncols"], len(self._metrics))):
            self._axes.flat[-(i + 1 + axis_offset)].set_xlabel("Number of iterations")

        self._pyplot.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))

    def plot(self, plot_batches: bool = True):
        """Plot the logged data.

        This method can be used to 'force' a plot to be drawn, and should
        *not* be called repeatedly while logging data. Instead, one should
        invoke ``Liveplot.set_train_batch(plot=True)``,
        ``Liveplot.set_train_epoch``, and ``Liveplot.set_test_epoch``,
        which will adjust their plot-rates according to
        ``Liveplot.max_fraction_spent_plotting``.

        ``LivePlot.plot`` should be called at the end of a logging-loop
        to ensure that the logged data is plotted in its entirety.

        This can also be used to recreate a plot after deserializing
        a ``LivePlot`` instance.

        Parameters
        ----------
        plot_batches : bool, optional (default=True)
            If ``True`` include batch-level data in plot.

        Raises
        ------
        TypeError
            ``plot_batches`` is not a bool."""
        if not isinstance(plot_batches, bool):
            raise TypeError(
                "`plot_batches` must be `True` or `False`, got {}".format(plot_batches)
            )

        self._init_plot_window()

        # plot batch-level train metrics
        for key, livedata in self._train_metrics.items():
            if livedata.batch_line is None:
                # initialize batch-level plot objects
                ax = self._axis_mapping[key]
                (livedata.batch_line,) = ax.plot(
                    [],
                    [],
                    label="train",
                    color=self._train_colors.get(key),
                    **self._batch_ax,
                )
                ax.set_title(key)
                ax.legend(**self._legend)

            if plot_batches and livedata.batch_line:
                n = (
                    self.last_n_batches
                    if self.last_n_batches
                    else len(livedata.batch_domain)
                )
                livedata.batch_line.set_xdata(livedata.batch_domain[-n:])
                livedata.batch_line.set_ydata(livedata.batch_data[-n:])

                # the legend entry reports the most recent epoch-level value
                if livedata.epoch_data.size:
                    livedata.batch_line.set_label(
                        "train: {:.2e}".format(livedata.epoch_data[-1])
                    )
            elif len(livedata.batch_line.get_xdata()):
                # clear batch-level plots
                livedata.batch_line.set_xdata(np.array([]))
                livedata.batch_line.set_ydata(np.array([]))

        # plot epoch-level train metrics
        for key, livedata in self._train_metrics.items():
            if livedata.epoch_line is None:
                # initialize epoch-level plot objects; the epoch line
                # shares the color of the metric's batch line
                ax = self._axis_mapping[key]
                batch_color = livedata.batch_line.get_color()
                (livedata.epoch_line,) = ax.plot(
                    [], [], color=batch_color, **self._epoch_ax
                )
                ax.legend(**self._legend)

            if livedata.epoch_line is not None and livedata.batch_domain.size:
                self._update_epoch_domain(
                    self.last_n_batches,
                    batch_domain=livedata.batch_domain,
                    epoch_domain_lookup=self._epoch_domain_lookup["train"],
                    livedata=livedata,
                )

        # plot epoch-level test metrics
        for key, livedata in self._test_metrics.items():
            # initialize epoch-level plot objects
            if livedata.epoch_line is None:
                ax = self._axis_mapping[key]
                (livedata.epoch_line,) = ax.plot(
                    [],
                    [],
                    label="test",
                    color=self._test_colors.get(key),
                    **self._epoch_ax,
                )
                ax.set_title(key)
                ax.legend(**self._legend)

            if livedata.epoch_line is not None and livedata.batch_domain.size:
                # bound the test epochs by the train-batch domain of the
                # same-named metric, when train data has been recorded for it
                if (
                    livedata.name in self._train_metrics
                    and self._train_metrics[livedata.name].batch_domain.size
                ):
                    batch_domain = self._train_metrics[livedata.name].batch_domain
                else:
                    batch_domain = livedata.batch_domain

                self._update_epoch_domain(
                    last_n_batches=self.last_n_batches,
                    batch_domain=batch_domain,
                    epoch_domain_lookup=self._epoch_domain_lookup["test"],
                    livedata=livedata,
                )

                if livedata.epoch_data.size:
                    livedata.epoch_line.set_label(
                        "test: " + "{:.2e}".format(livedata.epoch_data[-1])
                    )

        # time the legend-update / rescale / draw portion of the plot
        s = time.time()
        self._update_text()
        self._resize()
        if self._liveplot and self._fig is not None:
            self._fig.canvas.draw()
        self._draw_time = time.time() - s

    @staticmethod
    def _update_epoch_domain(
        last_n_batches: int,
        batch_domain: np.ndarray,
        epoch_domain_lookup: Dict[str, int],
        livedata: LiveMetric,
    ):
        """
        Finds the oldest epoch batch-iteration within `last_n_batches` and
        sets the epoch-data such that it satisfies that bound.

        Parameters
        ----------
        last_n_batches : int
            If falsy (e.g. ``None``), all of the epoch-data is plotted.

        batch_domain : numpy.ndarray
            The training batch data; must be non-empty when `last_n_batches`
            is specified.

        epoch_domain_lookup : Dict[str, int]
            metric-name -> batch-iteration of previous earliest-plotted-epoch.
            Updated in-place.

        livedata : LiveMetric
            The metric being updated
        """
        if last_n_batches:
            # resume the search from the previously-found starting index
            old_n = epoch_domain_lookup[livedata.name]
            n = np.searchsorted(
                livedata.epoch_domain[old_n:], batch_domain[-last_n_batches:][0]
            )
            n += old_n
        else:
            n = 0

        epoch_domain_lookup[livedata.name] = n

        livedata.epoch_line.set_xdata(livedata.epoch_domain[n:])
        livedata.epoch_line.set_ydata(livedata.epoch_data[n:])

    def _timed_plot(self, plot_batches: bool):
        """Plots, and records how long the plot took in `_plot_time_queue`."""
        plot_start_time = time.time()
        self.plot(plot_batches=plot_batches)
        self._last_plot_time = time.time()

        # retain, at most, the `_queue_size` most-recent plot times
        if len(self._plot_time_queue) == self._queue_size:
            self._plot_time_queue.popleft()
        self._plot_time_queue.append(self._last_plot_time - plot_start_time)

    def _do_liveplot(self):
        """Plots only if live-plotting is supported and doing so keeps within
        the `max_fraction_spent_plotting` budget."""
        # enable active plotting upon first plot
        if self._last_plot_time is None:
            if self._liveplot:
                self._pyplot.ion()
            self._last_plot_time = time.time()

        if not self._liveplot:
            return

        self._time_of_last_liveplot_attempt = time.time()
        time_since_last_plot = (
            self._time_of_last_liveplot_attempt - self._last_plot_time
        )

        mean_plot_time = (
            sum(self._plot_time_queue) / len(self._plot_time_queue)
            if self._plot_time_queue
            else 0.0
        )

        if self.max_fraction_spent_plotting == 1.0 or (
            time_since_last_plot
            and mean_plot_time / (time_since_last_plot + mean_plot_time)
            < self.max_fraction_spent_plotting
        ):
            self._timed_plot(plot_batches=self._plot_batch)

        # exclude plot time
        self._time_of_last_liveplot_attempt = time.time()

    def _resize(self):
        """Rescales every axis to fit its current data."""
        if self._axes is None:  # pragma: no cover
            return

        for ax in self._axes.flat:
            ax.relim()
            ax.autoscale_view()

    def _update_text(self):
        """Redraws the legend of each metric's axes so that updated labels show."""
        for ax in self._axis_mapping.values():
            ax.legend(**self._legend)

    def show(self):  # pragma: no cover
        """
        Calls ``matplotlib.pyplot.show()``. For visualizing a static-plot"""
        if not self._liveplot:
            self._pyplot.show()